From b8fb19cfe8f7b8a0bd5d0f2b6981cd61465ace74 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 12 Jan 2026 11:47:01 -0800 Subject: [PATCH 1/3] Address rmm setup in bootstrap bootstrap_dask_cluster --- .../rapidsmpf/rapidsmpf/integrations/core.py | 133 +++++++++++++++++- python/rapidsmpf/rapidsmpf/tests/test_dask.py | 133 +++++++++++++++++- 2 files changed, 261 insertions(+), 5 deletions(-) diff --git a/python/rapidsmpf/rapidsmpf/integrations/core.py b/python/rapidsmpf/rapidsmpf/integrations/core.py index 1bd4b0d78..7a6c857b3 100644 --- a/python/rapidsmpf/rapidsmpf/integrations/core.py +++ b/python/rapidsmpf/rapidsmpf/integrations/core.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 """Shuffler integration with external libraries.""" @@ -670,6 +670,128 @@ def spill_func( return ctx.spill_collection.spill(amount, stream=DEFAULT_STREAM, device_mr=mr) +def setup_rmm_pool( + option_prefix: str, + options: Options, + worker: Any = None, +) -> rmm.mr.DeviceMemoryResource: + """ + Set up an RMM memory pool based on configuration options. + + This function mirrors the logic in dask-cuda's RMMSetup plugin to ensure + consistent behavior when rapidsmpf manages RMM instead of Dask or another + orchestrator. + + Parameters + ---------- + option_prefix + Prefix for config-option names (e.g., "dask_"). + options + Configuration options. + worker + Optional worker reference (for logging filenames). + + Returns + ------- + The configured RMM memory resource, or the current device resource + if no RMM options are specified. + + Notes + ----- + If no RMM configuration options are provided, this function returns + the current device resource without modification. + """ + # Parse RMM configuration options + # Note: Use Optional(None) as default for byte-size options. OptionalBytes can't + # be used because it calls parse_bytes() before checking for disabled keywords. + # When the user provides a value, we parse it with OptionalBytes. + initial_pool_size_opt = options.get_or_default( + f"{option_prefix}rmm_pool_size", default_value=Optional(None) + ).value + initial_pool_size = ( + OptionalBytes(initial_pool_size_opt).value + if initial_pool_size_opt is not None + else None + ) + maximum_pool_size_opt = options.get_or_default( + f"{option_prefix}rmm_maximum_pool_size", default_value=Optional(None) + ).value + maximum_pool_size = ( + OptionalBytes(maximum_pool_size_opt).value + if maximum_pool_size_opt is not None + else None + ) + managed_memory = options.get_or_default( + f"{option_prefix}rmm_managed_memory", default_value=False + ) + async_alloc = options.get_or_default( + f"{option_prefix}rmm_async", default_value=False + ) + release_threshold_opt = options.get_or_default( + f"{option_prefix}rmm_release_threshold", default_value=Optional(None) + ).value + release_threshold = ( + OptionalBytes(release_threshold_opt).value + if release_threshold_opt is not None + else None + ) + track_allocations = options.get_or_default( + f"{option_prefix}rmm_track_allocations", default_value=False + ) + + # If no RMM options specified, return current resource unchanged + if ( + initial_pool_size is None + and not managed_memory + and not async_alloc + and not track_allocations + ): + return rmm.mr.get_current_device_resource() + + # Validation (same as dask-cuda) + if async_alloc and managed_memory: + raise ValueError("`rmm_managed_memory` is incompatible with `rmm_async`.") + if not async_alloc and release_threshold is not None: + raise ValueError("`rmm_release_threshold` requires `rmm_async`.") + + # Setup based on mode (mirrors dask-cuda's RMMSetup plugin) + if async_alloc: + # Async allocation path using CudaAsyncMemoryResource + mr = rmm.mr.CudaAsyncMemoryResource( + initial_pool_size=initial_pool_size or 0, + release_threshold=release_threshold or 0, + ) + if maximum_pool_size is not None: + mr = rmm.mr.LimitingResourceAdaptor(mr, allocation_limit=maximum_pool_size) + rmm.mr.set_current_device_resource(mr) + elif initial_pool_size is not None or managed_memory: + # Pool/managed allocation path + # Choose the upstream memory resource + if managed_memory: + upstream_mr = rmm.mr.ManagedMemoryResource() + else: + upstream_mr = rmm.mr.CudaMemoryResource() + + # Create pool if initial_pool_size is specified + if initial_pool_size is not None: + mr = rmm.mr.PoolMemoryResource( + upstream_mr=upstream_mr, + initial_pool_size=initial_pool_size, + maximum_pool_size=maximum_pool_size, + ) + else: + # Only managed memory, no pool + mr = upstream_mr + rmm.mr.set_current_device_resource(mr) + + # Optionally wrap with tracking adaptor + if track_allocations: + mr = rmm.mr.get_current_device_resource() + rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr)) + + return rmm.mr.get_current_device_resource() + + def rmpf_worker_setup( worker: Any, option_prefix: str, @@ -698,10 +820,13 @@ def rmpf_worker_setup( Warnings -------- - This function creates a new RMM memory pool, and - sets it as the current device resource. + This function may create a new RMM memory pool (if configured), + and sets the RMM resource adaptor as the current device resource. """ - # Insert RMM resource adaptor on top of the current RMM resource stack. + # Set up RMM pool if configured (otherwise uses existing resource) + setup_rmm_pool(option_prefix, options, worker) + + # Insert RMM resource adaptor on top of the (now configured) RMM resource stack. mr = RmmResourceAdaptor( upstream_mr=rmm.mr.get_current_device_resource(), fallback_mr=( diff --git a/python/rapidsmpf/rapidsmpf/tests/test_dask.py b/python/rapidsmpf/rapidsmpf/tests/test_dask.py index 3741edf0f..e0825192c 100644 --- a/python/rapidsmpf/rapidsmpf/tests/test_dask.py +++ b/python/rapidsmpf/rapidsmpf/tests/test_dask.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations @@ -9,6 +9,8 @@ import dask.dataframe as dd import pytest +import rmm.mr + import rapidsmpf.integrations.single from rapidsmpf.communicator import COMMUNICATORS from rapidsmpf.config import Options @@ -17,6 +19,7 @@ dask_cudf_join, dask_cudf_shuffle, ) +from rapidsmpf.integrations.core import setup_rmm_pool from rapidsmpf.integrations.dask.core import get_worker_context from rapidsmpf.integrations.dask.shuffler import ( clear_shuffle_statistics, @@ -79,6 +82,134 @@ def get_rank(dask_worker: Worker) -> int: assert set(result.values()) == set(range(len(cluster.workers))) +def assert_workers_contain_mr_type( + client: Client, resource_type: type[rmm.mr.MemoryResource] +) -> None: + """Assert that all workers contain the given resource type.""" + + def _has_rmm_resource_type() -> bool: + # The given resource type must be in the resource chain. + mr = rmm.mr.get_current_device_resource() + while mr is not None: + if isinstance(mr, resource_type): + return True + mr = getattr(mr, "upstream_mr", None) + return False + + result = client.run(_has_rmm_resource_type) + for worker_addr, has_type in result.items(): + assert has_type, ( + f"Worker {worker_addr} does not have {resource_type.__name__} in chain." + ) + + +@gen_test(timeout=30) +async def test_dask_rmm_setup_async() -> None: + """Test that rapidsmpf can set up RMM async pool during bootstrap.""" + with ( + LocalCUDACluster( + scheduler_port=0, + device_memory_limit=1, + # Don't let dask-cuda set up RMM - rapidsmpf will do it + rmm_pool_size=None, + rmm_async=False, + ) as cluster, + Client(cluster) as client, + ): + bootstrap_dask_cluster( + client, + options=Options( + { + "dask_spill_device": "0.1", + "dask_rmm_async": "true", + "dask_rmm_pool_size": "128 MiB", + } + ), + ) + assert_workers_contain_mr_type(client, rmm.mr.CudaAsyncMemoryResource) + + +@gen_test(timeout=30) +async def test_dask_rmm_setup_pool() -> None: + """Test that rapidsmpf can set up RMM pool (non-async) during bootstrap.""" + with ( + LocalCUDACluster( + scheduler_port=0, + device_memory_limit=1, + # Don't let dask-cuda set up RMM - rapidsmpf will do it + rmm_pool_size=None, + rmm_async=False, + ) as cluster, + Client(cluster) as client, + ): + bootstrap_dask_cluster( + client, + options=Options( + { + "dask_spill_device": "0.1", + "dask_rmm_async": "false", + "dask_rmm_pool_size": "128 MiB", + } + ), + ) + + assert_workers_contain_mr_type(client, rmm.mr.PoolMemoryResource) + + +@gen_test(timeout=30) +async def test_dask_rmm_setup_track_allocations() -> None: + """Test that rapidsmpf can enable RMM tracking during bootstrap.""" + with ( + LocalCUDACluster( + scheduler_port=0, + device_memory_limit=1, + rmm_pool_size=None, + rmm_async=False, + ) as cluster, + Client(cluster) as client, + ): + bootstrap_dask_cluster( + client, + options=Options( + { + "dask_spill_device": "0.1", + "dask_rmm_track_allocations": "true", + } + ), + ) + + assert_workers_contain_mr_type(client, rmm.mr.TrackingResourceAdaptor) + + +def test_rmm_setup_validation_errors() -> None: + """Test that invalid RMM option combinations raise errors.""" + + # rmm_managed_memory is incompatible with rmm_async + with pytest.raises(ValueError, match="incompatible"): + setup_rmm_pool( + "test_", + Options( + { + "test_rmm_async": "true", + "test_rmm_managed_memory": "true", + } + ), + ) + + # rmm_release_threshold requires rmm_async + with pytest.raises(ValueError, match="requires"): + setup_rmm_pool( + "test_", + Options( + { + "test_rmm_async": "false", + "test_rmm_pool_size": "128 MiB", + "test_rmm_release_threshold": "64 MiB", + } + ), + ) + + @pytest.mark.parametrize("partition_count", [None, 3]) @pytest.mark.parametrize("sort", [True, False]) def test_dask_cudf_integration( From 69827c888fb760e74ff611fe725ce9818ba5a8ba Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 12 Jan 2026 13:57:51 -0800 Subject: [PATCH 2/3] fix byte-parsing --- .../rapidsmpf/rapidsmpf/integrations/core.py | 73 ++++++++++++++----- python/rapidsmpf/rapidsmpf/tests/test_dask.py | 34 +++++++-- 2 files changed, 85 insertions(+), 22 deletions(-) diff --git a/python/rapidsmpf/rapidsmpf/integrations/core.py b/python/rapidsmpf/rapidsmpf/integrations/core.py index 7a6c857b3..ecd19ac03 100644 --- a/python/rapidsmpf/rapidsmpf/integrations/core.py +++ b/python/rapidsmpf/rapidsmpf/integrations/core.py @@ -670,6 +670,42 @@ def spill_func( return ctx.spill_collection.spill(amount, stream=DEFAULT_STREAM, device_mr=mr) +def _parse_pool_size( + value: float | None, total_device_memory: int, *, alignment_size: int = 256 +) -> int | None: + """ + Parse a pool size value, supporting fractions of total device memory. + + Parameters + ---------- + value + Can be: + - None: Returns None + - A float in (0, 1]: Interpreted as fraction of total device memory + - A float > 1: Interpreted as byte count + total_device_memory + Total device memory in bytes. + alignment_size + Byte alignment (RMM pools require 256-byte alignment). + + Returns + ------- + Parsed byte count aligned to alignment_size, or None. + """ + if value is None or value == 0: + return None + + if 0.0 < value <= 1.0: + # Fraction of device memory + byte_count = int(total_device_memory * value) + else: + # Already a byte count + byte_count = int(value) + + # Align to alignment_size + return (byte_count // alignment_size) * alignment_size + + def setup_rmm_pool( option_prefix: str, options: Options, @@ -701,40 +737,43 @@ def setup_rmm_pool( If no RMM configuration options are provided, this function returns the current device resource without modification. """ + # Get total device memory for parsing fractional pool sizes + total_device_memory = rmm.mr.available_device_memory()[1] + # Parse RMM configuration options - # Note: Use Optional(None) as default for byte-size options. OptionalBytes can't - # be used because it calls parse_bytes() before checking for disabled keywords. - # When the user provides a value, we parse it with OptionalBytes. + # Pool sizes follow the same pattern as spill_device: values in (0, 1] are + # fractions of device memory, values > 1 are byte counts. initial_pool_size_opt = options.get_or_default( f"{option_prefix}rmm_pool_size", default_value=Optional(None) ).value - initial_pool_size = ( - OptionalBytes(initial_pool_size_opt).value - if initial_pool_size_opt is not None - else None + initial_pool_size = _parse_pool_size( + float(initial_pool_size_opt) if initial_pool_size_opt is not None else None, + total_device_memory, ) + maximum_pool_size_opt = options.get_or_default( f"{option_prefix}rmm_maximum_pool_size", default_value=Optional(None) ).value - maximum_pool_size = ( - OptionalBytes(maximum_pool_size_opt).value - if maximum_pool_size_opt is not None - else None + maximum_pool_size = _parse_pool_size( + float(maximum_pool_size_opt) if maximum_pool_size_opt is not None else None, + total_device_memory, ) + managed_memory = options.get_or_default( f"{option_prefix}rmm_managed_memory", default_value=False ) async_alloc = options.get_or_default( f"{option_prefix}rmm_async", default_value=False ) + release_threshold_opt = options.get_or_default( f"{option_prefix}rmm_release_threshold", default_value=Optional(None) ).value - release_threshold = ( - OptionalBytes(release_threshold_opt).value - if release_threshold_opt is not None - else None + release_threshold = _parse_pool_size( + float(release_threshold_opt) if release_threshold_opt is not None else None, + total_device_memory, ) + track_allocations = options.get_or_default( f"{option_prefix}rmm_track_allocations", default_value=False ) @@ -758,8 +797,8 @@ def setup_rmm_pool( if async_alloc: # Async allocation path using CudaAsyncMemoryResource mr = rmm.mr.CudaAsyncMemoryResource( - initial_pool_size=initial_pool_size or 0, - release_threshold=release_threshold or 0, + initial_pool_size=initial_pool_size, + release_threshold=release_threshold, ) if maximum_pool_size is not None: mr = rmm.mr.LimitingResourceAdaptor(mr, allocation_limit=maximum_pool_size) diff --git a/python/rapidsmpf/rapidsmpf/tests/test_dask.py b/python/rapidsmpf/rapidsmpf/tests/test_dask.py index e0825192c..16599cc89 100644 --- a/python/rapidsmpf/rapidsmpf/tests/test_dask.py +++ b/python/rapidsmpf/rapidsmpf/tests/test_dask.py @@ -19,7 +19,7 @@ dask_cudf_join, dask_cudf_shuffle, ) -from rapidsmpf.integrations.core import setup_rmm_pool +from rapidsmpf.integrations.core import _parse_pool_size, setup_rmm_pool from rapidsmpf.integrations.dask.core import get_worker_context from rapidsmpf.integrations.dask.shuffler import ( clear_shuffle_statistics, @@ -122,7 +122,7 @@ async def test_dask_rmm_setup_async() -> None: { "dask_spill_device": "0.1", "dask_rmm_async": "true", - "dask_rmm_pool_size": "128 MiB", + "dask_rmm_pool_size": "0.1", # 10% of device memory } ), ) @@ -148,7 +148,7 @@ async def test_dask_rmm_setup_pool() -> None: { "dask_spill_device": "0.1", "dask_rmm_async": "false", - "dask_rmm_pool_size": "128 MiB", + "dask_rmm_pool_size": "0.1", # 10% of device memory } ), ) @@ -203,13 +203,37 @@ def test_rmm_setup_validation_errors() -> None: Options( { "test_rmm_async": "false", - "test_rmm_pool_size": "128 MiB", - "test_rmm_release_threshold": "64 MiB", + "test_rmm_pool_size": "0.1", # 10% of device memory + "test_rmm_release_threshold": "0.05", # 5% of device memory } ), ) +# Use a fixed total memory for deterministic tests +_TOTAL_MEMORY = 16 * 1024 * 1024 * 1024 # 16 GiB + + +@pytest.mark.parametrize( + "value,expected", + [ + # None and zero return None + (None, None), + (0, None), + (0.0, None), + # Fractions (0 < value <= 1) are % of device memory + (0.5, 8 * 1024 * 1024 * 1024), # 50% of 16 GiB = 8 GiB + (1.0, _TOTAL_MEMORY), # 100% of device memory + # Values > 1 are byte counts (aligned to 256) + (1024 * 1024 * 1024, 1024 * 1024 * 1024), # 1 GiB + (1000, 768), # floor(1000 / 256) * 256 + ], +) +def test_parse_pool_size(value: float | None, expected: int | None) -> None: + """Test _parse_pool_size with various inputs.""" + assert _parse_pool_size(value, _TOTAL_MEMORY) == expected + + @pytest.mark.parametrize("partition_count", [None, 3]) @pytest.mark.parametrize("sort", [True, False]) def test_dask_cudf_integration( From d16f818599197314b0f2eb3969cf5d599bd0a03e Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 12 Jan 2026 14:31:02 -0800 Subject: [PATCH 3/3] clean up --- .../rapidsmpf/rapidsmpf/integrations/core.py | 77 +++++-------------- python/rapidsmpf/rapidsmpf/tests/test_dask.py | 26 +------ .../rapidsmpf/rapidsmpf/tests/test_utils.py | 46 ++++++++++- python/rapidsmpf/rapidsmpf/utils/string.py | 71 ++++++++++++++++- 4 files changed, 134 insertions(+), 86 deletions(-) diff --git a/python/rapidsmpf/rapidsmpf/integrations/core.py b/python/rapidsmpf/rapidsmpf/integrations/core.py index ecd19ac03..29dc7fb47 100644 --- a/python/rapidsmpf/rapidsmpf/integrations/core.py +++ b/python/rapidsmpf/rapidsmpf/integrations/core.py @@ -26,6 +26,7 @@ from rapidsmpf.rmm_resource_adaptor import RmmResourceAdaptor from rapidsmpf.shuffler import Shuffler from rapidsmpf.statistics import Statistics +from rapidsmpf.utils.string import parse_bytes_threshold if TYPE_CHECKING: from collections.abc import Callable, Sequence @@ -670,42 +671,6 @@ def spill_func( return ctx.spill_collection.spill(amount, stream=DEFAULT_STREAM, device_mr=mr) -def _parse_pool_size( - value: float | None, total_device_memory: int, *, alignment_size: int = 256 -) -> int | None: - """ - Parse a pool size value, supporting fractions of total device memory. - - Parameters - ---------- - value - Can be: - - None: Returns None - - A float in (0, 1]: Interpreted as fraction of total device memory - - A float > 1: Interpreted as byte count - total_device_memory - Total device memory in bytes. - alignment_size - Byte alignment (RMM pools require 256-byte alignment). - - Returns - ------- - Parsed byte count aligned to alignment_size, or None. - """ - if value is None or value == 0: - return None - - if 0.0 < value <= 1.0: - # Fraction of device memory - byte_count = int(total_device_memory * value) - else: - # Already a byte count - byte_count = int(value) - - # Align to alignment_size - return (byte_count // alignment_size) * alignment_size - - def setup_rmm_pool( option_prefix: str, options: Options, @@ -740,40 +705,36 @@ def setup_rmm_pool( # Get total device memory for parsing fractional pool sizes total_device_memory = rmm.mr.available_device_memory()[1] - # Parse RMM configuration options - # Pool sizes follow the same pattern as spill_device: values in (0, 1] are - # fractions of device memory, values > 1 are byte counts. - initial_pool_size_opt = options.get_or_default( - f"{option_prefix}rmm_pool_size", default_value=Optional(None) - ).value - initial_pool_size = _parse_pool_size( - float(initial_pool_size_opt) if initial_pool_size_opt is not None else None, + # Parse RMM configuration options using parse_bytes_threshold. + # Values in (0, 1] are fractions of device memory, values > 1 are byte counts. + # RMM pools require 256-byte alignment. + initial_pool_size = parse_bytes_threshold( + options.get_or_default( + f"{option_prefix}rmm_pool_size", default_value=Optional(None) + ).value, total_device_memory, + alignment=256, ) - - maximum_pool_size_opt = options.get_or_default( - f"{option_prefix}rmm_maximum_pool_size", default_value=Optional(None) - ).value - maximum_pool_size = _parse_pool_size( - float(maximum_pool_size_opt) if maximum_pool_size_opt is not None else None, + maximum_pool_size = parse_bytes_threshold( + options.get_or_default( + f"{option_prefix}rmm_maximum_pool_size", default_value=Optional(None) + ).value, total_device_memory, + alignment=256, ) - managed_memory = options.get_or_default( f"{option_prefix}rmm_managed_memory", default_value=False ) async_alloc = options.get_or_default( f"{option_prefix}rmm_async", default_value=False ) - - release_threshold_opt = options.get_or_default( - f"{option_prefix}rmm_release_threshold", default_value=Optional(None) - ).value - release_threshold = _parse_pool_size( - float(release_threshold_opt) if release_threshold_opt is not None else None, + release_threshold = parse_bytes_threshold( + options.get_or_default( + f"{option_prefix}rmm_release_threshold", default_value=Optional(None) + ).value, total_device_memory, + alignment=256, ) - track_allocations = options.get_or_default( f"{option_prefix}rmm_track_allocations", default_value=False ) diff --git a/python/rapidsmpf/rapidsmpf/tests/test_dask.py b/python/rapidsmpf/rapidsmpf/tests/test_dask.py index 16599cc89..2cfb76d48 100644 --- a/python/rapidsmpf/rapidsmpf/tests/test_dask.py +++ b/python/rapidsmpf/rapidsmpf/tests/test_dask.py @@ -19,7 +19,7 @@ dask_cudf_join, dask_cudf_shuffle, ) -from rapidsmpf.integrations.core import _parse_pool_size, setup_rmm_pool +from rapidsmpf.integrations.core import setup_rmm_pool from rapidsmpf.integrations.dask.core import get_worker_context from rapidsmpf.integrations.dask.shuffler import ( clear_shuffle_statistics, @@ -210,30 +210,6 @@ def test_rmm_setup_validation_errors() -> None: ) -# Use a fixed total memory for deterministic tests -_TOTAL_MEMORY = 16 * 1024 * 1024 * 1024 # 16 GiB - - -@pytest.mark.parametrize( - "value,expected", - [ - # None and zero return None - (None, None), - (0, None), - (0.0, None), - # Fractions (0 < value <= 1) are % of device memory - (0.5, 8 * 1024 * 1024 * 1024), # 50% of 16 GiB = 8 GiB - (1.0, _TOTAL_MEMORY), # 100% of device memory - # Values > 1 are byte counts (aligned to 256) - (1024 * 1024 * 1024, 1024 * 1024 * 1024), # 1 GiB - (1000, 768), # floor(1000 / 256) * 256 - ], -) -def test_parse_pool_size(value: float | None, expected: int | None) -> None: - """Test _parse_pool_size with various inputs.""" - assert _parse_pool_size(value, _TOTAL_MEMORY) == expected - - @pytest.mark.parametrize("partition_count", [None, 3]) @pytest.mark.parametrize("sort", [True, False]) def test_dask_cudf_integration( diff --git a/python/rapidsmpf/rapidsmpf/tests/test_utils.py b/python/rapidsmpf/rapidsmpf/tests/test_utils.py index 6cc255b00..953530a31 100644 --- a/python/rapidsmpf/rapidsmpf/tests/test_utils.py +++ b/python/rapidsmpf/rapidsmpf/tests/test_utils.py @@ -1,10 +1,15 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import pytest -from rapidsmpf.utils.string import format_bytes, parse_boolean, parse_bytes +from rapidsmpf.utils.string import ( + format_bytes, + parse_boolean, + parse_bytes, + parse_bytes_threshold, +) def test_format_bytes() -> None: @@ -41,3 +46,40 @@ def test_parse_boolean_false(input_str: str) -> None: def test_parse_boolean_invalid(invalid_str: str) -> None: with pytest.raises(ValueError, match="Cannot parse boolean"): parse_boolean(invalid_str) + + +# Use a fixed total for deterministic tests +_TOTAL = 16 * 1024 * 1024 * 1024 # 16 GiB + + +@pytest.mark.parametrize( + "value,expected", + [ + # None and zero return None + (None, None), + (0, None), + (0.0, None), + # Fractions (0 < value <= 1) are % of total + (0.5, 8 * 1024 * 1024 * 1024), # 50% of 16 GiB = 8 GiB + (1.0, _TOTAL), # 100% + ("0.5", 8 * 1024 * 1024 * 1024), # String fraction + # Values > 1 are byte counts + (1024 * 1024 * 1024, 1024 * 1024 * 1024), # 1 GiB + ("1000000000", 1000000000), # String byte count + # Byte strings (e.g., "12 GB", "128 MiB") + ("1 GiB", 1024 * 1024 * 1024), + ("12 GB", 12 * 1000 * 1000 * 1000), + ], +) +def test_parse_bytes_threshold(value: str | float | None, expected: int | None) -> None: + assert parse_bytes_threshold(value, _TOTAL) == expected + + +def test_parse_bytes_threshold_alignment() -> None: + # Results should be aligned to the specified alignment + assert ( + parse_bytes_threshold(1000, 1000, alignment=256) == 768 + ) # floor(1000 / 256) * 256 + assert ( + parse_bytes_threshold(0.5, 1000, alignment=100) == 500 + ) # 500 is divisible by 100 diff --git a/python/rapidsmpf/rapidsmpf/utils/string.py b/python/rapidsmpf/rapidsmpf/utils/string.py index 066b1b55d..2184e54a5 100644 --- a/python/rapidsmpf/rapidsmpf/utils/string.py +++ b/python/rapidsmpf/rapidsmpf/utils/string.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 """Useful string utilities.""" @@ -118,6 +118,75 @@ def parse_bytes(s: str | int) -> int: return int(number * unit_multipliers[unit]) +def parse_bytes_threshold( + threshold: str | float | None, + total: int, + *, + alignment: int = 1, +) -> int | None: + """ + Parse a threshold value that can be a fraction, byte count, or byte string. + + Parameters + ---------- + threshold + The threshold value. Can be: + - None: Returns None + - A float/string in (0, 1]: Interpreted as a fraction of `total` + - A float/string > 1: Interpreted as an absolute byte count + - A byte string like "12 GB" or "128 MiB": Parsed as bytes + total + The total size (e.g., total device memory) used to compute + fractional thresholds. + alignment + Byte alignment for the result. + + Returns + ------- + The parsed byte count aligned to `alignment`, or None. + + Examples + -------- + >>> parse_bytes_threshold(None, 1000) + >>> parse_bytes_threshold(0.5, 1000) + 500 + >>> parse_bytes_threshold("0.5", 1000) + 500 + >>> parse_bytes_threshold(100, 1000) + 100 + >>> parse_bytes_threshold("12 GB", 1000) + 12000000000 + >>> parse_bytes_threshold(100, 1000, alignment=256) + 0 + """ + if threshold is None: + return None + + # Try to parse as a byte string (e.g., "12 GB", "128 MiB") + try: + if ( + isinstance(threshold, (str, int)) + and (maybe_bytes := parse_bytes(threshold)) >= 1 + ): + threshold = maybe_bytes + except (TypeError, ValueError): + pass + + value = float(threshold) + if value == 0: + return None + + if 0.0 < value <= 1.0: + # Fraction of total + byte_count = int(total * value) + else: + # Absolute byte count + byte_count = int(value) + + # Align to alignment + return (byte_count // alignment) * alignment + + def parse_boolean(boolean: str) -> bool: """ Parse a string into a boolean value.