diff --git a/cpp/benchmarks/streaming/ndsh/utils.cpp b/cpp/benchmarks/streaming/ndsh/utils.cpp index 9081d55d0..2e421a51f 100644 --- a/cpp/benchmarks/streaming/ndsh/utils.cpp +++ b/cpp/benchmarks/streaming/ndsh/utils.cpp @@ -121,7 +121,6 @@ streaming::TableChunk to_device( std::shared_ptr create_context( ProgramOptions& arguments, RmmResourceAdaptor* mr ) { - rmm::mr::set_current_device_resource(mr); rmm::mr::set_current_device_resource_ref(mr); std::unordered_map memory_available{}; if (arguments.spill_device_limit.has_value()) { diff --git a/cpp/benchmarks/utils/rmm_stack.hpp b/cpp/benchmarks/utils/rmm_stack.hpp index 8ff0b0d38..3e7045eb4 100644 --- a/cpp/benchmarks/utils/rmm_stack.hpp +++ b/cpp/benchmarks/utils/rmm_stack.hpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -44,8 +44,6 @@ set_current_rmm_stack(std::string const& name) { } else { RAPIDSMPF_FAIL("unknown RMM stack name: " + name); } - // Note, RMM maintains two default resources, we set both here. - rmm::mr::set_current_device_resource(ret.get()); rmm::mr::set_current_device_resource_ref(*ret); return ret; } @@ -61,7 +59,6 @@ set_device_mem_resource_with_stats() { auto ret = std::make_shared( cudf::get_current_device_resource_ref() ); - rmm::mr::set_current_device_resource(ret.get()); rmm::mr::set_current_device_resource_ref(*ret); return ret; } diff --git a/cpp/include/rapidsmpf/memory/buffer_resource.hpp b/cpp/include/rapidsmpf/memory/buffer_resource.hpp index 8bc7df73c..2f6254502 100644 --- a/cpp/include/rapidsmpf/memory/buffer_resource.hpp +++ b/cpp/include/rapidsmpf/memory/buffer_resource.hpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: Apache-2.0 */ @@ -13,6 +13,8 @@ #include #include +#include + #include #include @@ -89,7 +91,7 @@ class BufferResource { * * @return Reference to the RMM resource used for device allocations. */ - [[nodiscard]] rmm::device_async_resource_ref device_mr() const noexcept { + [[nodiscard]] rmm::device_async_resource_ref device_mr() noexcept { return device_mr_; } @@ -365,7 +367,7 @@ class BufferResource { private: std::mutex mutex_; - rmm::device_async_resource_ref device_mr_; + cuda::mr::any_resource device_mr_; std::shared_ptr pinned_mr_; HostMemoryResource host_mr_; std::unordered_map memory_available_; diff --git a/cpp/include/rapidsmpf/rmm_resource_adaptor.hpp b/cpp/include/rapidsmpf/rmm_resource_adaptor.hpp index 663fecce8..3a050ed15 100644 --- a/cpp/include/rapidsmpf/rmm_resource_adaptor.hpp +++ b/cpp/include/rapidsmpf/rmm_resource_adaptor.hpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: Apache-2.0 */ @@ -14,6 +14,8 @@ #include #include +#include + #include #include #include @@ -51,7 +53,7 @@ class RmmResourceAdaptor final : public rmm::mr::device_memory_resource { * * @return Reference to the RMM memory resource. */ - [[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() const noexcept { + [[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() noexcept { return primary_mr_; } @@ -63,8 +65,11 @@ class RmmResourceAdaptor final : public rmm::mr::device_memory_resource { * @return Optional reference to the fallback RMM memory resource. */ [[nodiscard]] std::optional - get_fallback_resource() const noexcept { - return fallback_mr_; + get_fallback_resource() noexcept { + if (fallback_mr_) { + return rmm::device_async_resource_ref{*fallback_mr_}; + } + return std::nullopt; } /** @@ -160,8 +165,8 @@ class RmmResourceAdaptor final : public rmm::mr::device_memory_resource { ) const noexcept override; mutable std::mutex mutex_; - rmm::device_async_resource_ref primary_mr_; - std::optional fallback_mr_; + cuda::mr::any_resource primary_mr_; + std::optional> fallback_mr_; std::unordered_set fallback_allocations_; /// Tracks memory statistics for the lifetime of the resource. diff --git a/cpp/src/rmm_resource_adaptor.cpp b/cpp/src/rmm_resource_adaptor.cpp index 1914e2c57..eab9215d6 100644 --- a/cpp/src/rmm_resource_adaptor.cpp +++ b/cpp/src/rmm_resource_adaptor.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: Apache-2.0 */ @@ -126,19 +126,19 @@ bool RmmResourceAdaptor::do_is_equal( if (cast == nullptr) { return false; } - // Manual comparison of optionals to avoid recursive constraint satisfaction in - // CCCL 3.2. std::optional::operator== triggers infinite concept checking when the - // wrapped type (rmm::device_async_resource_ref) inherits from CCCL's concept-based - // resource_ref. - // TODO: Revert this after the RMM resource ref types are replaced with - // plain cuda::mr ref types. This depends on - // https://github.com/rapidsai/rmm/issues/2011. - auto this_fallback = get_fallback_resource(); - auto other_fallback = cast->get_fallback_resource(); - bool fallbacks_equal = - (this_fallback.has_value() == other_fallback.has_value()) - && (!this_fallback.has_value() || (*this_fallback == *other_fallback)); - return get_upstream_resource() == cast->get_upstream_resource() && fallbacks_equal; + // Compare the owned any_resource members directly. + // Note: We must extract values from std::optional before comparing to avoid + // recursive constraint satisfaction in CCCL 3.2's concept checking when using + // std::optional::operator== with any_resource. + bool fallbacks_equal; + if (fallback_mr_.has_value() != cast->fallback_mr_.has_value()) { + fallbacks_equal = false; + } else if (!fallback_mr_.has_value()) { + fallbacks_equal = true; + } else { + fallbacks_equal = (*fallback_mr_ == *cast->fallback_mr_); + } + return (primary_mr_ == cast->primary_mr_) && fallbacks_equal; } } // namespace rapidsmpf