Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cpp/benchmarks/streaming/ndsh/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ streaming::TableChunk to_device(
std::shared_ptr<streaming::Context> create_context(
ProgramOptions& arguments, RmmResourceAdaptor* mr
) {
rmm::mr::set_current_device_resource(mr);
rmm::mr::set_current_device_resource_ref(mr);
std::unordered_map<MemoryType, BufferResource::MemoryAvailable> memory_available{};
if (arguments.spill_device_limit.has_value()) {
Expand Down
5 changes: 1 addition & 4 deletions cpp/benchmarks/utils/rmm_stack.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
Expand Down Expand Up @@ -44,8 +44,6 @@ set_current_rmm_stack(std::string const& name) {
} else {
RAPIDSMPF_FAIL("unknown RMM stack name: " + name);
}
// Note, RMM maintains two default resources, we set both here.
rmm::mr::set_current_device_resource(ret.get());
rmm::mr::set_current_device_resource_ref(*ret);
return ret;
}
Expand All @@ -61,7 +59,6 @@ set_device_mem_resource_with_stats() {
auto ret = std::make_shared<rapidsmpf::RmmResourceAdaptor>(
cudf::get_current_device_resource_ref()
);
rmm::mr::set_current_device_resource(ret.get());
rmm::mr::set_current_device_resource_ref(*ret);
return ret;
}
8 changes: 5 additions & 3 deletions cpp/include/rapidsmpf/memory/buffer_resource.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -13,6 +13,8 @@
#include <unordered_map>
#include <utility>

#include <cuda/memory_resource>

#include <rmm/cuda_stream_pool.hpp>

#include <rapidsmpf/error.hpp>
Expand Down Expand Up @@ -89,7 +91,7 @@ class BufferResource {
*
* @return Reference to the RMM resource used for device allocations.
*/
[[nodiscard]] rmm::device_async_resource_ref device_mr() const noexcept {
[[nodiscard]] rmm::device_async_resource_ref device_mr() noexcept {
return device_mr_;
}
Comment on lines +94 to 96
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why this accessor, which returns device_mr_ as a resource_ref, cannot be const?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let’s discuss this on the RMM PR thread, since it is the same question: rapidsai/rmm#2201 (comment)


Expand Down Expand Up @@ -365,7 +367,7 @@ class BufferResource {

private:
std::mutex mutex_;
rmm::device_async_resource_ref device_mr_;
cuda::mr::any_resource<cuda::mr::device_accessible> device_mr_;
std::shared_ptr<PinnedMemoryResource> pinned_mr_;
HostMemoryResource host_mr_;
std::unordered_map<MemoryType, MemoryAvailable> memory_available_;
Expand Down
17 changes: 11 additions & 6 deletions cpp/include/rapidsmpf/rmm_resource_adaptor.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -14,6 +14,8 @@
#include <unordered_set>
#include <utility>

#include <cuda/memory_resource>

#include <rmm/error.hpp>
#include <rmm/mr/device_memory_resource.hpp>
#include <rmm/resource_ref.hpp>
Expand Down Expand Up @@ -51,7 +53,7 @@ class RmmResourceAdaptor final : public rmm::mr::device_memory_resource {
*
* @return Reference to the RMM memory resource.
*/
[[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() const noexcept {
[[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() noexcept {
return primary_mr_;
}

Expand All @@ -63,8 +65,11 @@ class RmmResourceAdaptor final : public rmm::mr::device_memory_resource {
* @return Optional reference to the fallback RMM memory resource.
*/
[[nodiscard]] std::optional<rmm::device_async_resource_ref>
get_fallback_resource() const noexcept {
return fallback_mr_;
get_fallback_resource() noexcept {
if (fallback_mr_) {
return rmm::device_async_resource_ref{*fallback_mr_};
}
return std::nullopt;
}

/**
Expand Down Expand Up @@ -160,8 +165,8 @@ class RmmResourceAdaptor final : public rmm::mr::device_memory_resource {
) const noexcept override;

mutable std::mutex mutex_;
rmm::device_async_resource_ref primary_mr_;
std::optional<rmm::device_async_resource_ref> fallback_mr_;
cuda::mr::any_resource<cuda::mr::device_accessible> primary_mr_;
std::optional<cuda::mr::any_resource<cuda::mr::device_accessible>> fallback_mr_;
std::unordered_set<void*> fallback_allocations_;

/// Tracks memory statistics for the lifetime of the resource.
Expand Down
28 changes: 14 additions & 14 deletions cpp/src/rmm_resource_adaptor.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -126,19 +126,19 @@ bool RmmResourceAdaptor::do_is_equal(
if (cast == nullptr) {
return false;
}
// Manual comparison of optionals to avoid recursive constraint satisfaction in
// CCCL 3.2. std::optional::operator== triggers infinite concept checking when the
// wrapped type (rmm::device_async_resource_ref) inherits from CCCL's concept-based
// resource_ref.
// TODO: Revert this after the RMM resource ref types are replaced with
// plain cuda::mr ref types. This depends on
// https://github.com/rapidsai/rmm/issues/2011.
auto this_fallback = get_fallback_resource();
auto other_fallback = cast->get_fallback_resource();
bool fallbacks_equal =
(this_fallback.has_value() == other_fallback.has_value())
&& (!this_fallback.has_value() || (*this_fallback == *other_fallback));
return get_upstream_resource() == cast->get_upstream_resource() && fallbacks_equal;
// Compare the owned any_resource members directly.
// Note: We must extract values from std::optional before comparing to avoid
// recursive constraint satisfaction in CCCL 3.2's concept checking when using
// std::optional::operator== with any_resource.
bool fallbacks_equal;
if (fallback_mr_.has_value() != cast->fallback_mr_.has_value()) {
fallbacks_equal = false;
} else if (!fallback_mr_.has_value()) {
fallbacks_equal = true;
} else {
fallbacks_equal = (*fallback_mr_ == *cast->fallback_mr_);
}
Comment on lines +133 to +140
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer the old logic, which initialises fallbacks_equal immediately. Is there a functional reason for this change? I think not.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ll investigate but I think it was failing that way.

return (primary_mr_ == cast->primary_mr_) && fallbacks_equal;
}

} // namespace rapidsmpf