From f7bd315f50c9312675d5ee815c7454622aff090d Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 19 Nov 2025 17:17:05 -0800 Subject: [PATCH 1/9] capturing data sizes Signed-off-by: niranda perera --- cpp/include/rapidsmpf/shuffler/postbox.hpp | 29 +++++ cpp/src/shuffler/postbox.cpp | 29 +++-- cpp/tests/test_shuffler.cpp | 117 ++++++++++++++++----- 3 files changed, 140 insertions(+), 35 deletions(-) diff --git a/cpp/include/rapidsmpf/shuffler/postbox.hpp b/cpp/include/rapidsmpf/shuffler/postbox.hpp index 48876d949..5e98cc367 100644 --- a/cpp/include/rapidsmpf/shuffler/postbox.hpp +++ b/cpp/include/rapidsmpf/shuffler/postbox.hpp @@ -4,6 +4,7 @@ */ #pragma once +#include #include #include #include @@ -124,13 +125,41 @@ class PostBox { */ [[nodiscard]] std::string str() const; + /** + * @brief Returns the size of the data in the specified memory type. + * + * @param mem_type The type of memory to query. + * @return The size of the data in the specified memory type. + */ + [[nodiscard]] constexpr size_t data_size(MemoryType mem_type) const { + return data_size_[static_cast(mem_type)]; + } + private: + constexpr size_t& data_size_ref(MemoryType mem_type) { + return data_size_[static_cast(mem_type)]; + } + + void increment_data_size(Chunk const& chunk) { + if (chunk.is_data_buffer_set()) { + data_size_ref(chunk.data_memory_type()) += chunk.concat_data_size(); + } + } + + void decrement_data_size(Chunk const& chunk) { + if (chunk.is_data_buffer_set()) { + data_size_ref(chunk.data_memory_type()) -= chunk.concat_data_size(); + } + } + // TODO: more fine-grained locking e.g. by locking each partition individually. mutable std::mutex mutex_; std::function key_map_fn_; ///< Function to map partition IDs to keys. std::unordered_map> pigeonhole_; ///< Storage for chunks, organized by a key and chunk ID. 
+ + std::array data_size_{}; }; /** diff --git a/cpp/src/shuffler/postbox.cpp b/cpp/src/shuffler/postbox.cpp index 4450ba517..80b2185bc 100644 --- a/cpp/src/shuffler/postbox.cpp +++ b/cpp/src/shuffler/postbox.cpp @@ -23,10 +23,9 @@ void PostBox::insert(Chunk&& chunk) { ); } std::lock_guard const lock(mutex_); - RAPIDSMPF_EXPECTS( - pigeonhole_[key].emplace(chunk.chunk_id(), std::move(chunk)).second, - "PostBox.insert(): chunk already exist" - ); + auto [it, inserted] = pigeonhole_[key].emplace(chunk.chunk_id(), std::move(chunk)); + RAPIDSMPF_EXPECTS(inserted, "PostBox.insert(): chunk already exist"); + increment_data_size(it->second); } template @@ -38,19 +37,29 @@ bool PostBox::is_empty(PartID pid) const { template Chunk PostBox::extract(PartID pid, ChunkID cid) { std::lock_guard const lock(mutex_); - return extract_item(pigeonhole_[key_map_fn_(pid)], cid).second; + auto chunk = std::move(extract_item(pigeonhole_[key_map_fn_(pid)], cid).second); + decrement_data_size(chunk); + return chunk; } template std::unordered_map PostBox::extract(PartID pid) { std::lock_guard const lock(mutex_); - return extract_value(pigeonhole_, key_map_fn_(pid)); + auto chunks = extract_value(pigeonhole_, key_map_fn_(pid)); + for (auto const& [cid, chunk] : chunks) { + decrement_data_size(chunk); + } + return chunks; } template std::unordered_map PostBox::extract_by_key(KeyType key) { std::lock_guard const lock(mutex_); - return extract_value(pigeonhole_, key); + auto chunks = extract_value(pigeonhole_, key); + for (auto const& [cid, chunk] : chunks) { + decrement_data_size(chunk); + } + return chunks; } template @@ -65,8 +74,10 @@ std::vector PostBox::extract_all_ready() { auto& chunks = pid_it->second; auto chunk_it = chunks.begin(); while (chunk_it != chunks.end()) { - if (chunk_it->second.is_ready()) { - ret.emplace_back(std::move(chunk_it->second)); + auto&& chunk = chunk_it->second; + if (chunk.is_ready()) { + decrement_data_size(chunk); + ret.emplace_back(std::move(chunk)); chunk_it = chunks.erase(chunk_it); } else { ++chunk_it; diff --git a/cpp/tests/test_shuffler.cpp b/cpp/tests/test_shuffler.cpp index 4a511f1a4..99eaa2a2f 100644 --- a/cpp/tests/test_shuffler.cpp +++ b/cpp/tests/test_shuffler.cpp @@ -372,19 +372,17 @@ class ConcurrentShuffleTest for (int t_id = 0; t_id < num_shufflers; t_id++) { // pass a copy of the insert_fn and insert_finished_fn to each thread - futures.push_back( - std::async( - std::launch::async, - [this, - t_id, - insert_fn1 = insert_fn, - insert_finished_fn1 = insert_finished_fn] { - ASSERT_NO_FATAL_FAILURE(this->RunTest( - t_id, std::move(insert_fn1), std::move(insert_finished_fn1) - )); - } - ) - ); + futures.push_back(std::async( + std::launch::async, + [this, + t_id, + insert_fn1 = insert_fn, + insert_finished_fn1 = insert_finished_fn] { + ASSERT_NO_FATAL_FAILURE(this->RunTest( + t_id, std::move(insert_fn1), std::move(insert_finished_fn1) + )); + } + )); } for (auto& f : futures) { @@ -484,9 +482,9 @@ class ShuffleInsertGroupedTest stream = cudf::get_default_stream(); - progress_thread = std::make_shared( - GlobalEnvironment->comm_->logger() - ); + progress_thread = + std::make_shared(GlobalEnvironment->comm_->logger() + ); GlobalEnvironment->barrier(); } @@ -704,9 +702,8 @@ TEST(Shuffler, SpillOnInsertAndExtraction) { rapidsmpf::BufferResource br{ mr, {{rapidsmpf::MemoryType::DEVICE, - [&device_memory_available]() -> std::int64_t { - return device_memory_available; - }}}, + [&device_memory_available]() -> std::int64_t { return device_memory_available; } + }}, std::nullopt 
// disable periodic spill check }; EXPECT_EQ( @@ -883,7 +880,9 @@ class FinishCounterMultithreadingTest n_finished_pids = 0; finish_counter = std::make_unique( - nranks, local_partitions, [&](rapidsmpf::shuffler::PartID pid) { + nranks, + local_partitions, + [&](rapidsmpf::shuffler::PartID pid) { { std::lock_guard lock(mtx); finished_pids.push_back(pid); @@ -1037,7 +1036,13 @@ class PostBoxTest : public cudf::test::BaseFixture { postbox.reset(); } + std::unique_ptr make_empty_buffer() { + return br.allocate(stream, br.reserve_or_fail(0, rapidsmpf::MemoryType::HOST)); + } + std::unique_ptr postbox; + rmm::cuda_stream_view stream{cudf::get_default_stream()}; + rapidsmpf::BufferResource br{cudf::get_current_device_resource_ref()}; }; TEST_F(PostBoxTest, EmptyPostbox) { @@ -1056,6 +1061,7 @@ TEST_F(PostBoxTest, InsertAndExtractMultipleChunks) { rapidsmpf::shuffler::detail::ChunkID{i}, rapidsmpf::shuffler::PartID{i % num_partitions} ); + chunk.set_data_buffer(make_empty_buffer()); postbox->insert(std::move(chunk)); } @@ -1099,6 +1105,7 @@ TEST_F(PostBoxTest, ThreadSafety) { rapidsmpf::shuffler::detail::ChunkID{i * chunks_per_thread + j}, rapidsmpf::shuffler::PartID{j / chunks_per_partition} ); + chunk.set_data_buffer(make_empty_buffer()); postbox->insert(std::move(chunk)); } }); @@ -1119,6 +1126,66 @@ TEST_F(PostBoxTest, ThreadSafety) { EXPECT_TRUE(postbox->empty()); } +TEST_F(PostBoxTest, DataSize) { + // Helper method to create packed data with specified sizes and memory type + auto create_packed_data = + [&](size_t metadata_size, size_t buffer_size, rapidsmpf::MemoryType mem_type + ) -> rapidsmpf::PackedData { + auto metadata = std::make_unique>(metadata_size); + auto data_reservation = br.reserve_or_fail(buffer_size, mem_type); + auto data = br.allocate(stream, std::move(data_reservation)); + return rapidsmpf::PackedData{std::move(metadata), std::move(data)}; + }; + + // Initially, both device and host data sizes should be 0 + EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::DEVICE), 0); + EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::HOST), 0); + + // Create a chunk with device memory (100 bytes of data) + auto chunk1 = rapidsmpf::shuffler::detail::Chunk::from_packed_data( + 1, 0, create_packed_data(50, 100, rapidsmpf::MemoryType::DEVICE) + ); + postbox->insert(std::move(chunk1)); + + // Device data size should be 100, host should still be 0 + EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::DEVICE), 100); + EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::HOST), 0); + + // Create another chunk with device memory (200 bytes of data) + auto chunk2 = rapidsmpf::shuffler::detail::Chunk::from_packed_data( + 2, 1, create_packed_data(30, 200, rapidsmpf::MemoryType::DEVICE) + ); + postbox->insert(std::move(chunk2)); + + // Device data size should now be 300 (100 + 200) + EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::DEVICE), 300); + EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::HOST), 0); + + // Create a chunk with host memory (150 bytes of data) + auto chunk3 = rapidsmpf::shuffler::detail::Chunk::from_packed_data( + 3, 2, create_packed_data(40, 150, rapidsmpf::MemoryType::HOST) + ); + + postbox->insert(std::move(chunk3)); + + // Device should still be 300, host should be 150 + EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::DEVICE), 300); + EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::HOST), 150); + + // Extract a chunk and verify sizes decrease + std::ignore = postbox->extract(0, 1); + EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::DEVICE), 200); + 
EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::HOST), 150); + + // Extract all remaining chunks + std::ignore = postbox->extract_all_ready(); + + // Both should be 0 after extracting all + EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::DEVICE), 0); + EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::HOST), 0); + EXPECT_TRUE(postbox->empty()); +} + TEST(Shuffler, ShutdownWhilePaused) { auto progress_thread = std::make_shared(GlobalEnvironment->comm_->logger()); @@ -1265,9 +1332,8 @@ TEST_F(ExtractEmptyPartitionsTest, SomeEmptyAndNonEmptyInsertions) { } insert_chunks(std::move(chunks)); - EXPECT_NO_FATAL_FAILURE(verify_extracted_chunks([](auto pid) { - return pid % 3 == 0; - })); + EXPECT_NO_FATAL_FAILURE(verify_extracted_chunks([](auto pid) { return pid % 3 == 0; }) + ); } TEST(ShufflerTest, multiple_shutdowns) { @@ -1284,9 +1350,8 @@ TEST(ShufflerTest, multiple_shutdowns) { constexpr int n_threads = 10; std::vector> futures; for (int i = 0; i < n_threads; ++i) { - futures.emplace_back(std::async(std::launch::async, [&] { - shuffler->shutdown(); - })); + futures.emplace_back(std::async(std::launch::async, [&] { shuffler->shutdown(); }) + ); } std::ranges::for_each(futures, [](auto& future) { future.get(); }); shuffler.reset(); From c7e7ea2e3afe357a575c78f9c9e1c81ee66cc00d Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 20 Nov 2025 16:28:52 -0800 Subject: [PATCH 2/9] Revert "capturing data sizes" This reverts commit f7bd315f50c9312675d5ee815c7454622aff090d. --- cpp/include/rapidsmpf/shuffler/postbox.hpp | 29 ----- cpp/src/shuffler/postbox.cpp | 29 ++--- cpp/tests/test_shuffler.cpp | 117 +++++---------------- 3 files changed, 35 insertions(+), 140 deletions(-) diff --git a/cpp/include/rapidsmpf/shuffler/postbox.hpp b/cpp/include/rapidsmpf/shuffler/postbox.hpp index 5e98cc367..48876d949 100644 --- a/cpp/include/rapidsmpf/shuffler/postbox.hpp +++ b/cpp/include/rapidsmpf/shuffler/postbox.hpp @@ -4,7 +4,6 @@ */ #pragma once -#include #include #include #include @@ -125,41 +124,13 @@ class PostBox { */ [[nodiscard]] std::string str() const; - /** - * @brief Returns the size of the data in the specified memory type. - * - * @param mem_type The type of memory to query. - * @return The size of the data in the specified memory type. - */ - [[nodiscard]] constexpr size_t data_size(MemoryType mem_type) const { - return data_size_[static_cast(mem_type)]; - } - private: - constexpr size_t& data_size_ref(MemoryType mem_type) { - return data_size_[static_cast(mem_type)]; - } - - void increment_data_size(Chunk const& chunk) { - if (chunk.is_data_buffer_set()) { - data_size_ref(chunk.data_memory_type()) += chunk.concat_data_size(); - } - } - - void decrement_data_size(Chunk const& chunk) { - if (chunk.is_data_buffer_set()) { - data_size_ref(chunk.data_memory_type()) -= chunk.concat_data_size(); - } - } - // TODO: more fine-grained locking e.g. by locking each partition individually. mutable std::mutex mutex_; std::function key_map_fn_; ///< Function to map partition IDs to keys. std::unordered_map> pigeonhole_; ///< Storage for chunks, organized by a key and chunk ID. 
- - std::array data_size_{}; }; /** diff --git a/cpp/src/shuffler/postbox.cpp b/cpp/src/shuffler/postbox.cpp index 80b2185bc..4450ba517 100644 --- a/cpp/src/shuffler/postbox.cpp +++ b/cpp/src/shuffler/postbox.cpp @@ -23,9 +23,10 @@ void PostBox::insert(Chunk&& chunk) { ); } std::lock_guard const lock(mutex_); - auto [it, inserted] = pigeonhole_[key].emplace(chunk.chunk_id(), std::move(chunk)); - RAPIDSMPF_EXPECTS(inserted, "PostBox.insert(): chunk already exist"); - increment_data_size(it->second); + RAPIDSMPF_EXPECTS( + pigeonhole_[key].emplace(chunk.chunk_id(), std::move(chunk)).second, + "PostBox.insert(): chunk already exist" + ); } template @@ -37,29 +38,19 @@ bool PostBox::is_empty(PartID pid) const { template Chunk PostBox::extract(PartID pid, ChunkID cid) { std::lock_guard const lock(mutex_); - auto chunk = std::move(extract_item(pigeonhole_[key_map_fn_(pid)], cid).second); - decrement_data_size(chunk); - return chunk; + return extract_item(pigeonhole_[key_map_fn_(pid)], cid).second; } template std::unordered_map PostBox::extract(PartID pid) { std::lock_guard const lock(mutex_); - auto chunks = extract_value(pigeonhole_, key_map_fn_(pid)); - for (auto const& [cid, chunk] : chunks) { - decrement_data_size(chunk); - } - return chunks; + return extract_value(pigeonhole_, key_map_fn_(pid)); } template std::unordered_map PostBox::extract_by_key(KeyType key) { std::lock_guard const lock(mutex_); - auto chunks = extract_value(pigeonhole_, key); - for (auto const& [cid, chunk] : chunks) { - decrement_data_size(chunk); - } - return chunks; + return extract_value(pigeonhole_, key); } template @@ -74,10 +65,8 @@ std::vector PostBox::extract_all_ready() { auto& chunks = pid_it->second; auto chunk_it = chunks.begin(); while (chunk_it != chunks.end()) { - auto&& chunk = chunk_it->second; - if (chunk.is_ready()) { - decrement_data_size(chunk); - ret.emplace_back(std::move(chunk)); + if (chunk_it->second.is_ready()) { + ret.emplace_back(std::move(chunk_it->second)); chunk_it = chunks.erase(chunk_it); } else { ++chunk_it; diff --git a/cpp/tests/test_shuffler.cpp b/cpp/tests/test_shuffler.cpp index 99eaa2a2f..4a511f1a4 100644 --- a/cpp/tests/test_shuffler.cpp +++ b/cpp/tests/test_shuffler.cpp @@ -372,17 +372,19 @@ class ConcurrentShuffleTest for (int t_id = 0; t_id < num_shufflers; t_id++) { // pass a copy of the insert_fn and insert_finished_fn to each thread - futures.push_back(std::async( - std::launch::async, - [this, - t_id, - insert_fn1 = insert_fn, - insert_finished_fn1 = insert_finished_fn] { - ASSERT_NO_FATAL_FAILURE(this->RunTest( - t_id, std::move(insert_fn1), std::move(insert_finished_fn1) - )); - } - )); + futures.push_back( + std::async( + std::launch::async, + [this, + t_id, + insert_fn1 = insert_fn, + insert_finished_fn1 = insert_finished_fn] { + ASSERT_NO_FATAL_FAILURE(this->RunTest( + t_id, std::move(insert_fn1), std::move(insert_finished_fn1) + )); + } + ) + ); } for (auto& f : futures) { @@ -482,9 +484,9 @@ class ShuffleInsertGroupedTest stream = cudf::get_default_stream(); - progress_thread = - std::make_shared(GlobalEnvironment->comm_->logger() - ); + progress_thread = std::make_shared( + GlobalEnvironment->comm_->logger() + ); GlobalEnvironment->barrier(); } @@ -702,8 +704,9 @@ TEST(Shuffler, SpillOnInsertAndExtraction) { rapidsmpf::BufferResource br{ mr, {{rapidsmpf::MemoryType::DEVICE, - [&device_memory_available]() -> std::int64_t { return device_memory_available; } - }}, + [&device_memory_available]() -> std::int64_t { + return device_memory_available; + }}}, std::nullopt 
// disable periodic spill check }; EXPECT_EQ( @@ -880,9 +883,7 @@ class FinishCounterMultithreadingTest n_finished_pids = 0; finish_counter = std::make_unique( - nranks, - local_partitions, - [&](rapidsmpf::shuffler::PartID pid) { + nranks, local_partitions, [&](rapidsmpf::shuffler::PartID pid) { { std::lock_guard lock(mtx); finished_pids.push_back(pid); @@ -1036,13 +1037,7 @@ class PostBoxTest : public cudf::test::BaseFixture { postbox.reset(); } - std::unique_ptr make_empty_buffer() { - return br.allocate(stream, br.reserve_or_fail(0, rapidsmpf::MemoryType::HOST)); - } - std::unique_ptr postbox; - rmm::cuda_stream_view stream{cudf::get_default_stream()}; - rapidsmpf::BufferResource br{cudf::get_current_device_resource_ref()}; }; TEST_F(PostBoxTest, EmptyPostbox) { @@ -1061,7 +1056,6 @@ TEST_F(PostBoxTest, InsertAndExtractMultipleChunks) { rapidsmpf::shuffler::detail::ChunkID{i}, rapidsmpf::shuffler::PartID{i % num_partitions} ); - chunk.set_data_buffer(make_empty_buffer()); postbox->insert(std::move(chunk)); } @@ -1105,7 +1099,6 @@ TEST_F(PostBoxTest, ThreadSafety) { rapidsmpf::shuffler::detail::ChunkID{i * chunks_per_thread + j}, rapidsmpf::shuffler::PartID{j / chunks_per_partition} ); - chunk.set_data_buffer(make_empty_buffer()); postbox->insert(std::move(chunk)); } }); @@ -1126,66 +1119,6 @@ TEST_F(PostBoxTest, ThreadSafety) { EXPECT_TRUE(postbox->empty()); } -TEST_F(PostBoxTest, DataSize) { - // Helper method to create packed data with specified sizes and memory type - auto create_packed_data = - [&](size_t metadata_size, size_t buffer_size, rapidsmpf::MemoryType mem_type - ) -> rapidsmpf::PackedData { - auto metadata = std::make_unique>(metadata_size); - auto data_reservation = br.reserve_or_fail(buffer_size, mem_type); - auto data = br.allocate(stream, std::move(data_reservation)); - return rapidsmpf::PackedData{std::move(metadata), std::move(data)}; - }; - - // Initially, both device and host data sizes should be 0 - EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::DEVICE), 0); - EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::HOST), 0); - - // Create a chunk with device memory (100 bytes of data) - auto chunk1 = rapidsmpf::shuffler::detail::Chunk::from_packed_data( - 1, 0, create_packed_data(50, 100, rapidsmpf::MemoryType::DEVICE) - ); - postbox->insert(std::move(chunk1)); - - // Device data size should be 100, host should still be 0 - EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::DEVICE), 100); - EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::HOST), 0); - - // Create another chunk with device memory (200 bytes of data) - auto chunk2 = rapidsmpf::shuffler::detail::Chunk::from_packed_data( - 2, 1, create_packed_data(30, 200, rapidsmpf::MemoryType::DEVICE) - ); - postbox->insert(std::move(chunk2)); - - // Device data size should now be 300 (100 + 200) - EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::DEVICE), 300); - EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::HOST), 0); - - // Create a chunk with host memory (150 bytes of data) - auto chunk3 = rapidsmpf::shuffler::detail::Chunk::from_packed_data( - 3, 2, create_packed_data(40, 150, rapidsmpf::MemoryType::HOST) - ); - - postbox->insert(std::move(chunk3)); - - // Device should still be 300, host should be 150 - EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::DEVICE), 300); - EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::HOST), 150); - - // Extract a chunk and verify sizes decrease - std::ignore = postbox->extract(0, 1); - EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::DEVICE), 200); - 
EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::HOST), 150); - - // Extract all remaining chunks - std::ignore = postbox->extract_all_ready(); - - // Both should be 0 after extracting all - EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::DEVICE), 0); - EXPECT_EQ(postbox->data_size(rapidsmpf::MemoryType::HOST), 0); - EXPECT_TRUE(postbox->empty()); -} - TEST(Shuffler, ShutdownWhilePaused) { auto progress_thread = std::make_shared(GlobalEnvironment->comm_->logger()); @@ -1332,8 +1265,9 @@ TEST_F(ExtractEmptyPartitionsTest, SomeEmptyAndNonEmptyInsertions) { } insert_chunks(std::move(chunks)); - EXPECT_NO_FATAL_FAILURE(verify_extracted_chunks([](auto pid) { return pid % 3 == 0; }) - ); + EXPECT_NO_FATAL_FAILURE(verify_extracted_chunks([](auto pid) { + return pid % 3 == 0; + })); } TEST(ShufflerTest, multiple_shutdowns) { @@ -1350,8 +1284,9 @@ TEST(ShufflerTest, multiple_shutdowns) { constexpr int n_threads = 10; std::vector> futures; for (int i = 0; i < n_threads; ++i) { - futures.emplace_back(std::async(std::launch::async, [&] { shuffler->shutdown(); }) - ); + futures.emplace_back(std::async(std::launch::async, [&] { + shuffler->shutdown(); + })); } std::ranges::for_each(futures, [](auto& future) { future.get(); }); shuffler.reset(); From e0b7467605bebae5b71f7c1b070d8f92c15f5a85 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 21 Nov 2025 15:12:48 -0800 Subject: [PATCH 3/9] finer locking without spilling Signed-off-by: niranda perera --- cpp/include/rapidsmpf/shuffler/chunk.hpp | 41 ++++++++ cpp/include/rapidsmpf/shuffler/postbox.hpp | 72 +++++++------- cpp/include/rapidsmpf/shuffler/shuffler.hpp | 9 +- cpp/src/shuffler/postbox.cpp | 104 ++++++++++---------- cpp/src/shuffler/shuffler.cpp | 87 +++------------- cpp/tests/test_shuffler.cpp | 14 +-- 6 files changed, 156 insertions(+), 171 deletions(-) diff --git a/cpp/include/rapidsmpf/shuffler/chunk.hpp b/cpp/include/rapidsmpf/shuffler/chunk.hpp index a8cc92384..a243eb3cf 100644 --- a/cpp/include/rapidsmpf/shuffler/chunk.hpp +++ b/cpp/include/rapidsmpf/shuffler/chunk.hpp @@ -469,3 +469,44 @@ inline std::ostream& operator<<(std::ostream& os, ReadyForDataMessage const& obj } // namespace detail } // namespace rapidsmpf::shuffler + +// Custom hash function for Chunk that uses chunk ID +namespace std { +template <> +/** + * @brief Hash function for Chunk. + */ +struct hash { + /** + * @brief Hash function for Chunk that uses chunk ID. + * + * @param chunk The chunk to hash. + * @return The hash of the chunk. + */ + std::size_t operator()( + rapidsmpf::shuffler::detail::Chunk const& chunk + ) const noexcept { + return std::hash{}(chunk.chunk_id()); + } +}; + +template <> +/** + * @brief Equality operator for Chunk. + */ +struct equal_to { + /** + * @brief Equality operator for Chunk that uses chunk ID. + * + * @param lhs The left chunk. + * @param rhs The right chunk. + * @return True if the chunks are equal, false otherwise. 
+ */ + bool operator()( + rapidsmpf::shuffler::detail::Chunk const& lhs, + rapidsmpf::shuffler::detail::Chunk const& rhs + ) const noexcept { + return lhs.chunk_id() == rhs.chunk_id(); + } +}; +} // namespace std diff --git a/cpp/include/rapidsmpf/shuffler/postbox.hpp b/cpp/include/rapidsmpf/shuffler/postbox.hpp index 48876d949..5450b0ade 100644 --- a/cpp/include/rapidsmpf/shuffler/postbox.hpp +++ b/cpp/include/rapidsmpf/shuffler/postbox.hpp @@ -4,13 +4,17 @@ */ #pragma once +#include #include #include +#include #include #include +#include #include #include +#include #include namespace rapidsmpf::shuffler::detail { @@ -31,16 +35,21 @@ class PostBox { * * @tparam Fn The type of the function that maps a partition ID to a key. * @param key_map_fn A function that maps a partition ID to a key. - * @param num_keys_hint The number of keys to reserve space for. + * @param keys The keys expected to be used in the PostBox. * * @note The `key_map_fn` must be convertible to a function that takes a `PartID` and * returns a `KeyType`. */ - template - PostBox(Fn&& key_map_fn, size_t num_keys_hint = 0) - : key_map_fn_(std::move(key_map_fn)) { - if (num_keys_hint > 0) { - pigeonhole_.reserve(num_keys_hint); + template + requires std::convertible_to, KeyType> + PostBox(Fn&& key_map_fn, Range&& keys) : key_map_fn_(std::move(key_map_fn)) { + pigeonhole_.reserve(std::ranges::size(keys)); + for (const auto& key : keys) { + pigeonhole_.emplace( + std::piecewise_construct, + std::forward_as_tuple(key), + std::forward_as_tuple() + ); } } @@ -62,17 +71,6 @@ class PostBox { */ bool is_empty(PartID pid) const; - /** - * @brief Extracts a specific chunk from the PostBox. - * - * @param pid The ID of the partition containing the chunk. - * @param cid The ID of the chunk to be accessed. - * @return The extracted chunk. - * - * @throws std::out_of_range If the chunk is not found. - */ - [[nodiscard]] Chunk extract(PartID pid, ChunkID cid); - /** * @brief Extracts all chunks associated with a specific partition. * @@ -81,7 +79,7 @@ class PostBox { * * @throws std::out_of_range If the partition is not found. */ - std::unordered_map extract(PartID pid); + std::vector extract(PartID pid); /** * @brief Extracts all chunks associated with a specific key. @@ -91,7 +89,7 @@ class PostBox { * * @throws std::out_of_range If the key is not found. */ - std::unordered_map extract_by_key(KeyType key); + std::vector extract_by_key(KeyType key); /** * @brief Extracts all ready chunks from the PostBox. @@ -107,30 +105,36 @@ class PostBox { */ [[nodiscard]] bool empty() const; - /** - * @brief Searches for chunks of the specified memory type. - * - * @param mem_type The type of memory to search within. - * @return A vector of tuples, where each tuple contains: PartID, ChunkID, and the - * size of the chunk. - */ - [[nodiscard]] std::vector> search( - MemoryType mem_type - ) const; - /** * @brief Returns a description of this instance. * @return The description. */ [[nodiscard]] std::string str() const; + /** + * @brief Spills the specified amount of data from the PostBox. + * + * @param br Buffer resource to use for spilling. + * @param amount The amount of data to spill. + * @return The amount of data spilled. + */ + size_t spill(BufferResource* br, size_t amount); + private: - // TODO: more fine-grained locking e.g. by locking each partition individually. - mutable std::mutex mutex_; + /** + * @brief Map value for the PostBox. + * + * @note The mutex is used to protect the chunks set. 
+ */ + struct MapValue { + mutable std::mutex mutex; ///< Mutex to protect each key + std::unordered_set chunks; ///< Set of chunks for the key + }; + std::function key_map_fn_; ///< Function to map partition IDs to keys. - std::unordered_map> - pigeonhole_; ///< Storage for chunks, organized by a key and chunk ID. + std::unordered_map pigeonhole_; ///< Storage for chunks + std::atomic n_non_empty_keys_{0}; }; /** diff --git a/cpp/include/rapidsmpf/shuffler/shuffler.hpp b/cpp/include/rapidsmpf/shuffler/shuffler.hpp index 4e3b66653..e3a063a42 100644 --- a/cpp/include/rapidsmpf/shuffler/shuffler.hpp +++ b/cpp/include/rapidsmpf/shuffler/shuffler.hpp @@ -341,29 +341,26 @@ class Shuffler { private: BufferResource* br_; + std::shared_ptr comm_; std::atomic active_{true}; + std::vector const local_partitions_; + detail::PostBox outgoing_postbox_; ///< Postbox for outgoing chunks, that are ///< ready to be sent to other ranks. detail::PostBox ready_postbox_; ///< Postbox for received chunks, that are ///< ready to be extracted by the user. - std::shared_ptr comm_; std::shared_ptr progress_thread_; ProgressThread::FunctionID progress_thread_function_id_; OpID const op_id_; SpillManager::SpillFunctionID spill_function_id_; - std::vector const local_partitions_; detail::FinishCounter finish_counter_; std::unordered_map outbound_chunk_counter_; mutable std::mutex outbound_chunk_counter_mutex_; - // We protect ready_postbox extraction to avoid returning a chunk that is in the - // process of being spilled by `Shuffler::spill`. - mutable std::mutex ready_postbox_spilling_mutex_; - std::atomic chunk_id_counter_{0}; std::shared_ptr statistics_; diff --git a/cpp/src/shuffler/postbox.cpp b/cpp/src/shuffler/postbox.cpp index 4450ba517..2d2f5265b 100644 --- a/cpp/src/shuffler/postbox.cpp +++ b/cpp/src/shuffler/postbox.cpp @@ -22,88 +22,92 @@ void PostBox::insert(Chunk&& chunk) { "PostBox.insert(): all messages in the chunk must map to the same key" ); } - std::lock_guard const lock(mutex_); + // std::lock_guard const lock(mutex_); + auto& map_value = pigeonhole_.at(key); + std::lock_guard lock(map_value.mutex); + if (map_value.chunks.empty()) { + RAPIDSMPF_EXPECTS( + n_non_empty_keys_.fetch_add(1, std::memory_order_relaxed) + 1 + <= pigeonhole_.size(), + "PostBox.insert(): n_non_empty_keys_ is already at the maximum" + ); + } RAPIDSMPF_EXPECTS( - pigeonhole_[key].emplace(chunk.chunk_id(), std::move(chunk)).second, + map_value.chunks.emplace(std::move(chunk)).second, "PostBox.insert(): chunk already exist" ); } template bool PostBox::is_empty(PartID pid) const { - std::lock_guard const lock(mutex_); - return !pigeonhole_.contains(key_map_fn_(pid)); + auto& map_value = pigeonhole_.at(key_map_fn_(pid)); + std::lock_guard lock(map_value.mutex); + return map_value.chunks.empty(); } template -Chunk PostBox::extract(PartID pid, ChunkID cid) { - std::lock_guard const lock(mutex_); - return extract_item(pigeonhole_[key_map_fn_(pid)], cid).second; +std::vector PostBox::extract(PartID pid) { + return extract_by_key(key_map_fn_(pid)); } template -std::unordered_map PostBox::extract(PartID pid) { - std::lock_guard const lock(mutex_); - return extract_value(pigeonhole_, key_map_fn_(pid)); -} +std::vector PostBox::extract_by_key(KeyType key) { + auto& map_value = pigeonhole_.at(key); + std::vector ret; + std::lock_guard lock(map_value.mutex); + RAPIDSMPF_EXPECTS(!map_value.chunks.empty(), "PostBox.extract(): partition is empty"); + ret.reserve(map_value.chunks.size()); -template -std::unordered_map 
PostBox::extract_by_key(KeyType key) { - std::lock_guard const lock(mutex_); - return extract_value(pigeonhole_, key); + for (auto it = map_value.chunks.begin(); it != map_value.chunks.end();) { + auto node = map_value.chunks.extract(it++); + ret.emplace_back(std::move(node.value())); + } + + RAPIDSMPF_EXPECTS( + n_non_empty_keys_.fetch_sub(1, std::memory_order_relaxed) > 0, + "PostBox.extract(): n_non_empty_keys_ is already 0" + ); + return ret; } template std::vector PostBox::extract_all_ready() { - std::lock_guard const lock(mutex_); std::vector ret; // Iterate through the outer map - auto pid_it = pigeonhole_.begin(); - while (pid_it != pigeonhole_.end()) { - // Iterate through the inner map - auto& chunks = pid_it->second; - auto chunk_it = chunks.begin(); - while (chunk_it != chunks.end()) { - if (chunk_it->second.is_ready()) { - ret.emplace_back(std::move(chunk_it->second)); - chunk_it = chunks.erase(chunk_it); + for (auto& [key, map_value] : pigeonhole_) { + std::lock_guard lock(map_value.mutex); + bool chunks_available = !map_value.chunks.empty(); + auto chunk_it = map_value.chunks.begin(); + while (chunk_it != map_value.chunks.end()) { + if (chunk_it->is_ready()) { + auto node = map_value.chunks.extract(chunk_it++); + ret.emplace_back(std::move(node.value())); } else { ++chunk_it; } } - // Remove the pid entry if its chunks map is empty - if (chunks.empty()) { - pid_it = pigeonhole_.erase(pid_it); - } else { - ++pid_it; + // if the chunks were available and are now empty, its fully extracted + if (chunks_available && map_value.chunks.empty()) { + RAPIDSMPF_EXPECTS( + n_non_empty_keys_.fetch_sub(1, std::memory_order_relaxed) > 0, + "PostBox.extract_all_ready(): n_non_empty_keys_ is already 0" + ); } } - return ret; } template bool PostBox::empty() const { - std::lock_guard const lock(mutex_); - return pigeonhole_.empty(); + return n_non_empty_keys_.load(std::memory_order_acquire) == 0; } template -std::vector> PostBox::search( - MemoryType mem_type -) const { - std::lock_guard const lock(mutex_); - std::vector> ret; - for (auto& [key, chunks] : pigeonhole_) { - for (auto& [cid, chunk] : chunks) { - if (!chunk.is_control_message(0) && chunk.data_memory_type() == mem_type) { - ret.emplace_back(key, cid, chunk.concat_data_size()); - } - } - } - return ret; +size_t PostBox::spill(BufferResource* /* br */, size_t /* amount */) { + // TODO: implement spill + return 0; } template @@ -113,14 +117,14 @@ std::string PostBox::str() const { } std::stringstream ss; ss << "PostBox("; - for (auto const& [key, chunks] : pigeonhole_) { + for (auto const& [key, map_value] : pigeonhole_) { ss << "k=" << key << ": ["; - for (auto const& [cid, chunk] : chunks) { - assert(cid == chunk.chunk_id()); + for (auto const& chunk : map_value.chunks) { + // assert(cid == chunk.chunk_id()); if (chunk.is_control_message(0)) { ss << "EOP" << chunk.expected_num_chunks(0) << ", "; } else { - ss << cid << ", "; + ss << chunk.chunk_id() << ", "; } } ss << "\b\b], "; diff --git a/cpp/src/shuffler/shuffler.cpp b/cpp/src/shuffler/shuffler.cpp index 9ccf3468e..4fcc748a2 100644 --- a/cpp/src/shuffler/shuffler.cpp +++ b/cpp/src/shuffler/shuffler.cpp @@ -85,69 +85,6 @@ std::unique_ptr allocate_buffer( return ret; } -/** - * @brief Spills memory buffers within a postbox, e.g., from device to host memory. - * - * This function moves a specified amount of memory from device to host storage - * or another lower-priority memory space, helping manage limited GPU memory - * by offloading excess data. 
- * - * The spilling is stream-ordered on the individual CUDA stream of each spilled buffer. - * - * @note While spilling, chunks are temporarily extracted from the postbox thus other - * threads trying to extract a chunk that is in the process of being spilled, will fail. - * To avoid this, the Shuffler uses `outbox_spillling_mutex_` to serialize extractions. - * - * @param br Buffer resource for GPU data allocations. - * @param log A logger for recording events and debugging information. - * @param statistics The statistics instance to use. - * @param stream CUDA stream to use for memory and kernel operations. - * @param amount The maximum amount of data (in bytes) to be spilled. - * - * @return The actual amount of data successfully spilled from the postbox. - * - * @warning This may temporarily empty the postbox, causing emptiness checks to return - * true even though the postbox is not actually empty. As a result, in the current - * implementation `postbox_spilling()` must not be used to spill `outgoing_postbox_`. - */ -template -std::size_t postbox_spilling( - BufferResource* br, - Communicator::Logger& log, - PostBox& postbox, - std::size_t amount -) { - RAPIDSMPF_NVTX_FUNC_RANGE(); - // Let's look for chunks to spill in the outbox. - auto const chunk_info = postbox.search(MemoryType::DEVICE); - std::size_t total_spilled{0}; - for (auto [pid, cid, size] : chunk_info) { - if (size == 0) { // skip empty data buffers - continue; - } - - // TODO: Use a clever strategy to decide which chunks to spill. For now, we - // just spill the chunks in an arbitrary order. - auto [host_reservation, host_overbooking] = - br->reserve(MemoryType::HOST, size, true); - if (host_overbooking > 0) { - log.warn( - "Cannot spill to host because of host memory overbooking: ", - format_nbytes(host_overbooking) - ); - continue; - } - // We extract the chunk, spilled it, and insert it back into the PostBox. 
- auto chunk = postbox.extract(pid, cid); - chunk.set_data_buffer(br->move(chunk.release_data_buffer(), host_reservation)); - postbox.insert(std::move(chunk)); - if ((total_spilled += size) >= amount) { - break; - } - } - return total_spilled; -} - } // namespace class Shuffler::Progress { @@ -476,20 +413,21 @@ Shuffler::Shuffler( : total_num_partitions{total_num_partitions}, partition_owner{std::move(partition_owner_fn)}, br_{br}, + comm_{std::move(comm)}, + local_partitions_{local_partitions(comm_, total_num_partitions, partition_owner)}, outgoing_postbox_{ [this](PartID pid) -> Rank { return this->partition_owner(this->comm_, pid); }, // extract Rank from pid - static_cast(comm->nranks()) + std::views::iota(Rank(0), Rank(this->comm_->nranks())) }, - ready_postbox_{ - [](PartID pid) -> PartID { return pid; }, // identity mapping - static_cast(total_num_partitions), + ready_postbox_{// [](PartID pid) -> PartID { return pid; }, // identity mapping + // static_cast(total_num_partitions), + std::identity{}, + local_partitions_ }, - comm_{std::move(comm)}, progress_thread_{std::move(progress_thread)}, op_id_{op_id}, - local_partitions_{local_partitions(comm_, total_num_partitions, partition_owner)}, finish_counter_{comm_->nranks(), local_partitions_, std::move(finished_callback)}, statistics_{std::move(statistics)} { RAPIDSMPF_EXPECTS(comm_ != nullptr, "the communicator pointer cannot be NULL"); @@ -768,7 +706,7 @@ void Shuffler::insert_finished(std::vector&& pids) { std::vector Shuffler::extract(PartID pid) { RAPIDSMPF_NVTX_FUNC_RANGE(); - std::unique_lock lock(ready_postbox_spilling_mutex_); + // std::unique_lock lock(ready_postbox_spilling_mutex_); // Quick return if the partition is empty. if (ready_postbox_.is_empty(pid)) { @@ -776,13 +714,13 @@ std::vector Shuffler::extract(PartID pid) { } auto chunks = ready_postbox_.extract(pid); - lock.unlock(); + // lock.unlock(); std::vector ret; ret.reserve(chunks.size()); - std::ranges::transform(chunks, std::back_inserter(ret), [](auto&& p) -> PackedData { - return {p.second.release_metadata_buffer(), p.second.release_data_buffer()}; + std::ranges::transform(chunks, std::back_inserter(ret), [](auto&& c) -> PackedData { + return {c.release_metadata_buffer(), c.release_data_buffer()}; }); return ret; @@ -815,8 +753,7 @@ std::size_t Shuffler::spill(std::optional amount) { } std::size_t spilled{0}; if (spill_need > 0) { - std::lock_guard lock(ready_postbox_spilling_mutex_); - spilled = postbox_spilling(br_, comm_->logger(), ready_postbox_, spill_need); + spilled = ready_postbox_.spill(br_, spill_need); } return spilled; } diff --git a/cpp/tests/test_shuffler.cpp b/cpp/tests/test_shuffler.cpp index 4a511f1a4..c54061af4 100644 --- a/cpp/tests/test_shuffler.cpp +++ b/cpp/tests/test_shuffler.cpp @@ -570,7 +570,7 @@ class ShuffleInsertGroupedTest } auto chunks = shuffler.outgoing_postbox_.extract_by_key(rank); - for (auto& [cid, chunk] : chunks) { + for (auto& chunk : chunks) { for (size_t i = 0; i < chunk.n_messages(); ++i) { outbound_chunks[chunk.part_id(i)]++; if (chunk.is_control_message(i)) { @@ -594,7 +594,7 @@ class ShuffleInsertGroupedTest // control messages outbound_chunks[pid]++; n_control_messages++; - for (auto& [cid, chunk] : local_chunks) { + for (auto& chunk : local_chunks) { for (size_t i = 0; i < chunk.n_messages(); ++i) { outbound_chunks[chunk.part_id(i)]++; @@ -1023,7 +1023,7 @@ class PostBoxTest : public cudf::test::BaseFixture { [this](rapidsmpf::shuffler::PartID part_id) { return partition_owner(part_id); }, - 
GlobalEnvironment->comm_->nranks() + std::views::iota(0, GlobalEnvironment->comm_->nranks()) ); } @@ -1068,9 +1068,11 @@ TEST_F(PostBoxTest, InsertAndExtractMultipleChunks) { auto chunks = postbox->extract_by_key(rank); extracted_nchunks += chunks.size(); - for (auto& [_, chunk] : chunks) { - extracted_chunks.emplace_back(std::move(chunk)); - } + extracted_chunks.insert( + extracted_chunks.end(), + std::make_move_iterator(chunks.begin()), + std::make_move_iterator(chunks.end()) + ); } EXPECT_EQ(extracted_nchunks, num_chunks); EXPECT_TRUE(postbox->empty()); From 670068b6d0c6bed2669413de1e14744fcb074b33 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 21 Nov 2025 16:27:50 -0800 Subject: [PATCH 4/9] add spilling Signed-off-by: niranda perera --- cpp/include/rapidsmpf/shuffler/chunk.hpp | 41 ---------- cpp/include/rapidsmpf/shuffler/postbox.hpp | 8 +- cpp/src/shuffler/postbox.cpp | 90 ++++++++++++++++------ cpp/src/shuffler/shuffler.cpp | 2 +- 4 files changed, 69 insertions(+), 72 deletions(-) diff --git a/cpp/include/rapidsmpf/shuffler/chunk.hpp b/cpp/include/rapidsmpf/shuffler/chunk.hpp index a243eb3cf..a8cc92384 100644 --- a/cpp/include/rapidsmpf/shuffler/chunk.hpp +++ b/cpp/include/rapidsmpf/shuffler/chunk.hpp @@ -469,44 +469,3 @@ inline std::ostream& operator<<(std::ostream& os, ReadyForDataMessage const& obj } // namespace detail } // namespace rapidsmpf::shuffler - -// Custom hash function for Chunk that uses chunk ID -namespace std { -template <> -/** - * @brief Hash function for Chunk. - */ -struct hash { - /** - * @brief Hash function for Chunk that uses chunk ID. - * - * @param chunk The chunk to hash. - * @return The hash of the chunk. - */ - std::size_t operator()( - rapidsmpf::shuffler::detail::Chunk const& chunk - ) const noexcept { - return std::hash{}(chunk.chunk_id()); - } -}; - -template <> -/** - * @brief Equality operator for Chunk. - */ -struct equal_to { - /** - * @brief Equality operator for Chunk that uses chunk ID. - * - * @param lhs The left chunk. - * @param rhs The right chunk. - * @return True if the chunks are equal, false otherwise. - */ - bool operator()( - rapidsmpf::shuffler::detail::Chunk const& lhs, - rapidsmpf::shuffler::detail::Chunk const& rhs - ) const noexcept { - return lhs.chunk_id() == rhs.chunk_id(); - } -}; -} // namespace std diff --git a/cpp/include/rapidsmpf/shuffler/postbox.hpp b/cpp/include/rapidsmpf/shuffler/postbox.hpp index 5450b0ade..0e2369e8f 100644 --- a/cpp/include/rapidsmpf/shuffler/postbox.hpp +++ b/cpp/include/rapidsmpf/shuffler/postbox.hpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include @@ -115,20 +114,19 @@ class PostBox { * @brief Spills the specified amount of data from the PostBox. * * @param br Buffer resource to use for spilling. + * @param log Logger to use for logging. * @param amount The amount of data to spill. * @return The amount of data spilled. */ - size_t spill(BufferResource* br, size_t amount); + size_t spill(BufferResource* br, Communicator::Logger& log, size_t amount); private: /** * @brief Map value for the PostBox. - * - * @note The mutex is used to protect the chunks set. 
*/ struct MapValue { mutable std::mutex mutex; ///< Mutex to protect each key - std::unordered_set chunks; ///< Set of chunks for the key + std::vector chunks; ///< Vector of chunks for the key }; std::function diff --git a/cpp/src/shuffler/postbox.cpp b/cpp/src/shuffler/postbox.cpp index 2d2f5265b..084083a8c 100644 --- a/cpp/src/shuffler/postbox.cpp +++ b/cpp/src/shuffler/postbox.cpp @@ -3,9 +3,12 @@ * SPDX-License-Identifier: Apache-2.0 */ +#include +#include #include #include +#include #include #include #include @@ -32,10 +35,7 @@ void PostBox::insert(Chunk&& chunk) { "PostBox.insert(): n_non_empty_keys_ is already at the maximum" ); } - RAPIDSMPF_EXPECTS( - map_value.chunks.emplace(std::move(chunk)).second, - "PostBox.insert(): chunk already exist" - ); + map_value.chunks.push_back(std::move(chunk)); } template @@ -53,15 +53,11 @@ std::vector PostBox::extract(PartID pid) { template std::vector PostBox::extract_by_key(KeyType key) { auto& map_value = pigeonhole_.at(key); - std::vector ret; std::lock_guard lock(map_value.mutex); RAPIDSMPF_EXPECTS(!map_value.chunks.empty(), "PostBox.extract(): partition is empty"); - ret.reserve(map_value.chunks.size()); - for (auto it = map_value.chunks.begin(); it != map_value.chunks.end();) { - auto node = map_value.chunks.extract(it++); - ret.emplace_back(std::move(node.value())); - } + std::vector ret = std::move(map_value.chunks); + map_value.chunks.clear(); RAPIDSMPF_EXPECTS( n_non_empty_keys_.fetch_sub(1, std::memory_order_relaxed) > 0, @@ -77,24 +73,33 @@ std::vector PostBox::extract_all_ready() { // Iterate through the outer map for (auto& [key, map_value] : pigeonhole_) { std::lock_guard lock(map_value.mutex); - bool chunks_available = !map_value.chunks.empty(); - auto chunk_it = map_value.chunks.begin(); - while (chunk_it != map_value.chunks.end()) { - if (chunk_it->is_ready()) { - auto node = map_value.chunks.extract(chunk_it++); - ret.emplace_back(std::move(node.value())); - } else { - ++chunk_it; - } - } - // if the chunks were available and are now empty, its fully extracted - if (chunks_available && map_value.chunks.empty()) { + // Partition: non-ready chunks first, ready chunks at the end + auto partition_point = + std::ranges::partition(map_value.chunks, [](const Chunk& c) { + return !c.is_ready(); + }).begin(); + + // if the chunks are available and all are ready, then all chunks will be + // extracted + if (map_value.chunks.begin() == partition_point + && partition_point != map_value.chunks.end()) + { RAPIDSMPF_EXPECTS( n_non_empty_keys_.fetch_sub(1, std::memory_order_relaxed) > 0, "PostBox.extract_all_ready(): n_non_empty_keys_ is already 0" ); } + + // Move ready chunks to result + ret.insert( + ret.end(), + std::make_move_iterator(partition_point), + std::make_move_iterator(map_value.chunks.end()) + ); + + // Remove ready chunks from the vector + map_value.chunks.erase(partition_point, map_value.chunks.end()); } return ret; } @@ -105,9 +110,44 @@ bool PostBox::empty() const { } template -size_t PostBox::spill(BufferResource* /* br */, size_t /* amount */) { - // TODO: implement spill - return 0; +size_t PostBox::spill( + BufferResource* br, Communicator::Logger& log, size_t amount +) { + RAPIDSMPF_NVTX_FUNC_RANGE(); + + // individually lock each key and spill the chunks in it. If we are unable to lock the + // key, then it will be skipped. 
+ size_t total_spilled = 0; + for (auto& [key, map_value] : pigeonhole_) { + std::unique_lock lock(map_value.mutex, std::try_to_lock); + if (lock) { // now all chunks in this key are locked + for (auto& chunk : map_value.chunks) { + if (chunk.is_data_buffer_set() + && chunk.data_memory_type() == MemoryType::DEVICE) + { + size_t size = chunk.concat_data_size(); + auto [host_reservation, host_overbooking] = + br->reserve(MemoryType::HOST, size, true); + if (host_overbooking > 0) { + log.warn( + "Cannot spill to host because of host memory overbooking: ", + format_nbytes(host_overbooking) + ); + continue; + } + chunk.set_data_buffer( + br->move(chunk.release_data_buffer(), host_reservation) + ); + total_spilled += size; + if (total_spilled >= amount) { + break; + } + } + } + } + } + + return total_spilled; } template diff --git a/cpp/src/shuffler/shuffler.cpp b/cpp/src/shuffler/shuffler.cpp index 4fcc748a2..18044af6b 100644 --- a/cpp/src/shuffler/shuffler.cpp +++ b/cpp/src/shuffler/shuffler.cpp @@ -753,7 +753,7 @@ std::size_t Shuffler::spill(std::optional amount) { } std::size_t spilled{0}; if (spill_need > 0) { - spilled = ready_postbox_.spill(br_, spill_need); + spilled = ready_postbox_.spill(br_, comm_->logger(), spill_need); } return spilled; } From 47ea8c1bb6022770055d284704db50adfce22fe5 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 24 Nov 2025 07:08:44 -0800 Subject: [PATCH 5/9] minor changes Signed-off-by: niranda perera --- cpp/include/rapidsmpf/shuffler/postbox.hpp | 2 +- cpp/src/shuffler/postbox.cpp | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/include/rapidsmpf/shuffler/postbox.hpp b/cpp/include/rapidsmpf/shuffler/postbox.hpp index 0e2369e8f..c26a18986 100644 --- a/cpp/include/rapidsmpf/shuffler/postbox.hpp +++ b/cpp/include/rapidsmpf/shuffler/postbox.hpp @@ -68,7 +68,7 @@ class PostBox { * @note The result reflects a snapshot at the time of the call and may change * immediately afterward. */ - bool is_empty(PartID pid) const; + [[nodiscard]] bool is_empty(PartID pid) const; /** * @brief Extracts all chunks associated with a specific partition. diff --git a/cpp/src/shuffler/postbox.cpp b/cpp/src/shuffler/postbox.cpp index 084083a8c..0ac3e9ee2 100644 --- a/cpp/src/shuffler/postbox.cpp +++ b/cpp/src/shuffler/postbox.cpp @@ -72,7 +72,10 @@ std::vector PostBox::extract_all_ready() { // Iterate through the outer map for (auto& [key, map_value] : pigeonhole_) { - std::lock_guard lock(map_value.mutex); + std::unique_lock lock(map_value.mutex, std::try_to_lock); + if (!lock.owns_lock()) { + continue; + } // Partition: non-ready chunks first, ready chunks at the end auto partition_point = From 966d3d398416ca8787047234d38f4896c767162d Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 24 Nov 2025 11:42:40 -0800 Subject: [PATCH 6/9] cosmetic changes Signed-off-by: niranda perera --- cpp/include/rapidsmpf/shuffler/postbox.hpp | 5 ++- cpp/src/shuffler/postbox.cpp | 47 ++++++++++++---------- cpp/src/shuffler/shuffler.cpp | 8 +--- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/cpp/include/rapidsmpf/shuffler/postbox.hpp b/cpp/include/rapidsmpf/shuffler/postbox.hpp index c26a18986..b89c393f5 100644 --- a/cpp/include/rapidsmpf/shuffler/postbox.hpp +++ b/cpp/include/rapidsmpf/shuffler/postbox.hpp @@ -132,7 +132,10 @@ class PostBox { std::function key_map_fn_; ///< Function to map partition IDs to keys. 
std::unordered_map pigeonhole_; ///< Storage for chunks - std::atomic n_non_empty_keys_{0}; + std::atomic n_non_empty_keys_{ + 0 + }; ///< Number of non-empty keys. Since the pigenhole map is not extracted, this + ///< count will be used to check the emptiness }; /** diff --git a/cpp/src/shuffler/postbox.cpp b/cpp/src/shuffler/postbox.cpp index 0ac3e9ee2..6385dd6b0 100644 --- a/cpp/src/shuffler/postbox.cpp +++ b/cpp/src/shuffler/postbox.cpp @@ -25,10 +25,11 @@ void PostBox::insert(Chunk&& chunk) { "PostBox.insert(): all messages in the chunk must map to the same key" ); } - // std::lock_guard const lock(mutex_); + auto& map_value = pigeonhole_.at(key); std::lock_guard lock(map_value.mutex); if (map_value.chunks.empty()) { + // this key is currently empty. So increment the non-empty key count. RAPIDSMPF_EXPECTS( n_non_empty_keys_.fetch_add(1, std::memory_order_relaxed) + 1 <= pigeonhole_.size(), @@ -123,28 +124,30 @@ size_t PostBox::spill( size_t total_spilled = 0; for (auto& [key, map_value] : pigeonhole_) { std::unique_lock lock(map_value.mutex, std::try_to_lock); - if (lock) { // now all chunks in this key are locked - for (auto& chunk : map_value.chunks) { - if (chunk.is_data_buffer_set() - && chunk.data_memory_type() == MemoryType::DEVICE) - { - size_t size = chunk.concat_data_size(); - auto [host_reservation, host_overbooking] = - br->reserve(MemoryType::HOST, size, true); - if (host_overbooking > 0) { - log.warn( - "Cannot spill to host because of host memory overbooking: ", - format_nbytes(host_overbooking) - ); - continue; - } - chunk.set_data_buffer( - br->move(chunk.release_data_buffer(), host_reservation) + if (!lock) { // skip to the next key + continue; + } + + for (auto& chunk : map_value.chunks) { + if (chunk.is_data_buffer_set() + && chunk.data_memory_type() == MemoryType::DEVICE) + { + size_t size = chunk.concat_data_size(); + auto [host_reservation, host_overbooking] = + br->reserve(MemoryType::HOST, size, true); + if (host_overbooking > 0) { + log.warn( + "Cannot spill to host because of host memory overbooking: ", + format_nbytes(host_overbooking) ); - total_spilled += size; - if (total_spilled >= amount) { - break; - } + continue; + } + chunk.set_data_buffer( + br->move(chunk.release_data_buffer(), host_reservation) + ); + total_spilled += size; + if (total_spilled >= amount) { + break; } } } diff --git a/cpp/src/shuffler/shuffler.cpp b/cpp/src/shuffler/shuffler.cpp index 18044af6b..3d2a7cc67 100644 --- a/cpp/src/shuffler/shuffler.cpp +++ b/cpp/src/shuffler/shuffler.cpp @@ -421,11 +421,7 @@ Shuffler::Shuffler( }, // extract Rank from pid std::views::iota(Rank(0), Rank(this->comm_->nranks())) }, - ready_postbox_{// [](PartID pid) -> PartID { return pid; }, // identity mapping - // static_cast(total_num_partitions), - std::identity{}, - local_partitions_ - }, + ready_postbox_{std::identity{}, local_partitions_}, progress_thread_{std::move(progress_thread)}, op_id_{op_id}, finish_counter_{comm_->nranks(), local_partitions_, std::move(finished_callback)}, @@ -706,7 +702,6 @@ void Shuffler::insert_finished(std::vector&& pids) { std::vector Shuffler::extract(PartID pid) { RAPIDSMPF_NVTX_FUNC_RANGE(); - // std::unique_lock lock(ready_postbox_spilling_mutex_); // Quick return if the partition is empty. 
if (ready_postbox_.is_empty(pid)) { @@ -714,7 +709,6 @@ std::vector Shuffler::extract(PartID pid) { } auto chunks = ready_postbox_.extract(pid); - // lock.unlock(); std::vector ret; ret.reserve(chunks.size()); From bfc1fe2fd609110bc68b0d29c29ee02adf2b80ca Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 24 Nov 2025 15:17:46 -0800 Subject: [PATCH 7/9] adding spiiling & insert Signed-off-by: niranda perera --- cpp/src/shuffler/postbox.cpp | 4 +- cpp/tests/test_shuffler.cpp | 172 ++++++++++++++++++++++++++++++++++- 2 files changed, 173 insertions(+), 3 deletions(-) diff --git a/cpp/src/shuffler/postbox.cpp b/cpp/src/shuffler/postbox.cpp index 6385dd6b0..96a41e86b 100644 --- a/cpp/src/shuffler/postbox.cpp +++ b/cpp/src/shuffler/postbox.cpp @@ -164,7 +164,7 @@ std::string PostBox::str() const { std::stringstream ss; ss << "PostBox("; for (auto const& [key, map_value] : pigeonhole_) { - ss << "k=" << key << ": ["; + ss << "k=" << key << ":["; for (auto const& chunk : map_value.chunks) { // assert(cid == chunk.chunk_id()); if (chunk.is_control_message(0)) { @@ -173,7 +173,7 @@ std::string PostBox::str() const { ss << chunk.chunk_id() << ", "; } } - ss << "\b\b], "; + ss << (map_value.chunks.empty() ? "], " : "\b\b], "); } ss << "\b\b)"; return ss.str(); diff --git a/cpp/tests/test_shuffler.cpp b/cpp/tests/test_shuffler.cpp index c54061af4..b7876b886 100644 --- a/cpp/tests/test_shuffler.cpp +++ b/cpp/tests/test_shuffler.cpp @@ -1012,7 +1012,7 @@ Chunk make_dummy_chunk(ChunkID chunk_id, PartID part_id) { } } // namespace rapidsmpf::shuffler::detail -class PostBoxTest : public cudf::test::BaseFixture { +class PostBoxTest : public ::testing::Test { protected: using PostboxType = rapidsmpf::shuffler::detail::PostBox; @@ -1294,3 +1294,173 @@ TEST(ShufflerTest, multiple_shutdowns) { shuffler.reset(); GlobalEnvironment->barrier(); } + +// Parameterized test for PostBox with concurrent insert/extract threads +class PostBoxMultithreadedTest + : public ::testing::TestWithParam> { + protected: + static constexpr size_t chunk_size = 1024; // 1KB + + void SetUp() override { + std::tie(num_threads, num_keys, chunks_per_key) = GetParam(); + stream = cudf::get_default_stream(); + + // Create buffer resource with chunk_size*num_keys/10 device memory limit + int64_t device_memory_limit = (num_keys * chunk_size) / 10; + + auto mr_ptr = cudf::get_current_device_resource_ref(); + mr = std::make_unique(mr_ptr); + + br = std::make_unique( + mr_ptr, + std::unordered_map< + rapidsmpf::MemoryType, + rapidsmpf::BufferResource::MemoryAvailable>{ + {rapidsmpf::MemoryType::DEVICE, + rapidsmpf::LimitAvailableMemory(mr.get(), device_memory_limit)} + }, + std::nullopt // disable periodic spill check + ); + + // Create PostBox with identity mapping: keys are [0, num_keys) + postbox = std::make_unique< + rapidsmpf::shuffler::detail::PostBox>( + std::identity{}, std::views::iota(0u, num_keys) + ); + + num_extracted_chunks.store(0, std::memory_order_relaxed); + + spill_function_id = br->spill_manager().add_spill_function( + [this](size_t amount) -> size_t { + return postbox->spill( + br.get(), GlobalEnvironment->comm_->logger(), amount + ); + }, + 0 + ); + } + + void TearDown() override { + postbox.reset(); + br->spill_manager().remove_spill_function(spill_function_id); + br.reset(); + mr.reset(); + } + + uint32_t num_threads; + uint32_t num_keys; + uint32_t chunks_per_key; + rmm::cuda_stream_view stream; + + std::unique_ptr mr; + std::unique_ptr br; + std::unique_ptr> + postbox; + rapidsmpf::SpillManager::SpillFunctionID 
spill_function_id; +}; + +TEST_P(PostBoxMultithreadedTest, ConcurrentInsertExtract) { + std::atomic chunk_id_counter{0}; + std::atomic completed_insert_threads{0}; + std::atomic num_extracted_chunks{0}; + std::vector> insert_futures; + std::vector> extract_futures; + + auto gen_keys = [this](uint32_t tid) { + std::unordered_set keys; + keys.reserve(num_keys / num_threads); + for (uint32_t key = tid; key < num_keys; key += num_threads) { + keys.insert(key); + } + return keys; + }; + + // insert & spill threads + for (uint32_t tid = 0; tid < num_threads; ++tid) { + insert_futures.emplace_back(std::async(std::launch::async, [&, tid] { + for (uint32_t p = 0; p < chunks_per_key; ++p) { + for (auto key : gen_keys(tid)) { + // Create chunk with 1KB device buffer and 1KB host buffer + auto chunk_id = + chunk_id_counter.fetch_add(1, std::memory_order_relaxed); + + // Reserve 1KB allocation using buffer resource reserve and spill + auto reservation = br->reserve_and_spill( + rapidsmpf::MemoryType::DEVICE, chunk_size, true + ); + + auto metadata = std::make_unique>( + chunk_size, static_cast(key) + ); + auto buffer = br->allocate(stream, std::move(reservation)); + + postbox->insert( + rapidsmpf::shuffler::detail::Chunk::from_packed_data( + chunk_id, key, {std::move(metadata), std::move(buffer)} + ) + ); + } + } + // Signal that this insert thread is done + completed_insert_threads.fetch_add(1, std::memory_order_release); + })); + } + + auto insert_done = [&]() { + return completed_insert_threads.load(std::memory_order_acquire) == num_threads; + }; + + // extract threads + for (uint32_t tid = 0; tid < num_threads; ++tid) { + extract_futures.emplace_back(std::async(std::launch::async, [&, tid] { + auto keys = gen_keys(tid); + + while (!keys.empty()) { // extact untill all keys are empty + for (auto it = keys.begin(); it != keys.end();) { + if (!postbox->is_empty(*it)) { + auto chunks = postbox->extract(*it); + auto chunk_count = static_cast(chunks.size()); + num_extracted_chunks.fetch_add( + chunk_count, std::memory_order_relaxed + ); + it++; + } else if (insert_done() && postbox->is_empty(*it)) { + it = keys.erase(it); + } else { + it++; + } + } + } + })); + } + + for (auto& future : insert_futures) { + EXPECT_NO_THROW(future.get()); + } + + for (auto& future : extract_futures) { + EXPECT_NO_THROW(future.get()); + } + + // Verify that all chunks were extracted + uint32_t expected_total_chunks = num_keys * chunks_per_key; + SCOPED_TRACE("postbox: " + postbox->str()); + EXPECT_EQ(expected_total_chunks, chunk_id_counter.load()); + EXPECT_EQ(expected_total_chunks, num_extracted_chunks.load()); + EXPECT_TRUE(postbox->empty()); +} + +INSTANTIATE_TEST_SUITE_P( + PostBoxConcurrent, + PostBoxMultithreadedTest, + testing::Combine( + testing::Values(1, 2), // num_threads + testing::Values(10, 100), // num_keys + testing::Values(1, 10) // chunks_per_key + ), + [](const testing::TestParamInfo& info) { + return "nthreads_" + std::to_string(std::get<0>(info.param)) + "_keys_" + + std::to_string(std::get<1>(info.param)) + "_chunks_" + + std::to_string(std::get<2>(info.param)); + } +); From 1a94262ddda64b3562c2ae1121b162f9d7f60f00 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 24 Nov 2025 15:33:58 -0800 Subject: [PATCH 8/9] minor bug Signed-off-by: niranda perera --- cpp/tests/test_shuffler.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/tests/test_shuffler.cpp b/cpp/tests/test_shuffler.cpp index b7876b886..3de3f4a51 100644 --- a/cpp/tests/test_shuffler.cpp +++ b/cpp/tests/test_shuffler.cpp @@ 
From 1a94262ddda64b3562c2ae1121b162f9d7f60f00 Mon Sep 17 00:00:00 2001
From: niranda perera
Date: Mon, 24 Nov 2025 15:33:58 -0800
Subject: [PATCH 8/9] minor bug

Signed-off-by: niranda perera
---
 cpp/tests/test_shuffler.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/cpp/tests/test_shuffler.cpp b/cpp/tests/test_shuffler.cpp
index b7876b886..3de3f4a51 100644
--- a/cpp/tests/test_shuffler.cpp
+++ b/cpp/tests/test_shuffler.cpp
@@ -1328,8 +1328,6 @@ class PostBoxMultithreadedTest
             std::identity{}, std::views::iota(0u, num_keys)
         );
 
-        num_extracted_chunks.store(0, std::memory_order_relaxed);
-
         spill_function_id = br->spill_manager().add_spill_function(
             [this](size_t amount) -> size_t {
                 return postbox->spill(

From 9ca5fdc52c7f39431292e475a653fce20d8c2c6b Mon Sep 17 00:00:00 2001
From: niranda perera
Date: Wed, 26 Nov 2025 14:33:51 -0800
Subject: [PATCH 9/9] WIP

Signed-off-by: niranda perera
---
 cpp/include/rapidsmpf/shuffler/postbox.hpp |  21 +-
 cpp/src/memory/spill_manager.cpp           |  14 +-
 cpp/src/shuffler/postbox.cpp               | 243 ++++++++++++++++-----
 3 files changed, 214 insertions(+), 64 deletions(-)

diff --git a/cpp/include/rapidsmpf/shuffler/postbox.hpp b/cpp/include/rapidsmpf/shuffler/postbox.hpp
index b89c393f5..859abc673 100644
--- a/cpp/include/rapidsmpf/shuffler/postbox.hpp
+++ b/cpp/include/rapidsmpf/shuffler/postbox.hpp
@@ -6,7 +6,9 @@
 #include
 #include
+#include
 #include
+#include
 #include
 #include
 #include
@@ -50,6 +52,8 @@ class PostBox {
             std::forward_as_tuple()
         );
     }
+        rng_ = std::mt19937(std::random_device{}());
+        dist_ = std::uniform_int_distribution(0, keys.size() - 1);
     }
 
 /**
@@ -126,16 +130,23 @@ class PostBox {
      */
     struct MapValue {
         mutable std::mutex mutex;  ///< Mutex to protect each key
-        std::vector chunks;  ///< Vector of chunks for the key
+        std::list ready_chunks;  ///< Vector of chunks for the key
+        size_t n_spilling_chunks{0};  ///< Number of chunks that are being spilled
+
+        [[nodiscard]] bool is_empty_unsafe() const noexcept {
+            return ready_chunks.empty() && n_spilling_chunks == 0;
+        }
     };
 
     std::function key_map_fn_;  ///< Function to map partition IDs to keys.
     std::unordered_map pigeonhole_;  ///< Storage for chunks
 
-    std::atomic n_non_empty_keys_{
-        0
-    };  ///< Number of non-empty keys. Since the pigenhole map is not extracted, this
-        ///< count will be used to check the emptiness
+    std::atomic n_chunks{0
+    };  ///< Number of chunks in the PostBox. Since the pigeonhole map is not extracted,
+        ///< this count will be used to check the emptiness
+    std::mt19937 rng_;  ///< Random number generator
+    std::uniform_int_distribution
+        dist_;  ///< Distribution for selecting a random key
 };
 
 /**

diff --git a/cpp/src/memory/spill_manager.cpp b/cpp/src/memory/spill_manager.cpp
index 17bfe5e97..a08239209 100644
--- a/cpp/src/memory/spill_manager.cpp
+++ b/cpp/src/memory/spill_manager.cpp
@@ -69,14 +69,22 @@ std::size_t SpillManager::spill(std::size_t amount) {
     std::size_t spilled{0};
     std::unique_lock lock(mutex_);
     auto const t0_elapsed = Clock::now();
-    for (auto const [_, fid] : spill_function_priorities_) {
+    // for (auto const [_, fid] : spill_function_priorities_) {
+    //     if (spilled >= amount) {
+    //         break;
+    //     }
+    //     spilled += spill_functions_.at(fid)(amount - spilled);
+    // }
+    auto spill_functions_cp = spill_functions_;
+    lock.unlock();
+
+    for (auto& [id, fn] : spill_functions_cp) {
         if (spilled >= amount) {
             break;
         }
-        spilled += spill_functions_.at(fid)(amount - spilled);
+        spilled += fn(amount - spilled);
     }
     auto const t1_elapsed = Clock::now();
-    lock.unlock();
     auto& stats = *br_->statistics();
     stats.add_duration_stat("spill-time-device-to-host", t1_elapsed - t0_elapsed);
     stats.add_bytes_stat("spill-bytes-device-to-host", spilled);
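In the `SpillManager::spill` hunk above, the registered spill functions are copied while the manager mutex is held and invoked only after it has been released, so a callback that takes its own locks (such as the PostBox per-key mutexes) or triggers further memory operations does not run under the manager mutex. A minimal standalone sketch of this snapshot-then-invoke pattern is below; the class and names are hypothetical, not the rapidsmpf API.

```cpp
#include <cstddef>
#include <functional>
#include <map>
#include <mutex>

// Simplified spill dispatcher: callbacks are registered under a lock, but a
// spill pass invokes a snapshot of them with the lock released, so callbacks
// may safely take other locks or call back into the dispatcher's owner.
class SpillDispatcher {
  public:
    using SpillFn = std::function<std::size_t(std::size_t)>;

    int add(SpillFn fn) {
        std::lock_guard lock(mutex_);
        int id = next_id_++;
        functions_[id] = std::move(fn);
        return id;
    }

    void remove(int id) {
        std::lock_guard lock(mutex_);
        functions_.erase(id);
    }

    std::size_t spill(std::size_t amount) {
        std::unique_lock lock(mutex_);
        auto snapshot = functions_;  // copy while holding the lock
        lock.unlock();               // release before running callbacks

        std::size_t spilled = 0;
        for (auto& [id, fn] : snapshot) {
            if (spilled >= amount) {
                break;
            }
            spilled += fn(amount - spilled);
        }
        return spilled;
    }

  private:
    std::mutex mutex_;
    int next_id_{0};
    std::map<int, SpillFn> functions_;
};
```

Note that iterating a copy of the id-to-function map, as in the hunk above, visits callbacks in registration order; the commented-out loop over `spill_function_priorities_` is what previously provided priority ordering.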
diff --git a/cpp/src/shuffler/postbox.cpp b/cpp/src/shuffler/postbox.cpp
index 96a41e86b..ae8bbde0d 100644
--- a/cpp/src/shuffler/postbox.cpp
+++ b/cpp/src/shuffler/postbox.cpp
@@ -28,26 +28,33 @@ void PostBox::insert(Chunk&& chunk) {
     auto& map_value = pigeonhole_.at(key);
     std::lock_guard lock(map_value.mutex);
-    if (map_value.chunks.empty()) {
-        // this key is currently empty. So increment the non-empty key count.
+    // RAPIDSMPF_EXPECTS(
+    //     n_non_empty_keys_.fetch_add(1, std::memory_order_relaxed) + 1
+    //         <= pigeonhole_.size(),
+    //     "PostBox.insert(): n_non_empty_keys_ is already at the maximum"
+    //     );
+    // }
+    n_chunks.fetch_add(1, std::memory_order_relaxed);
+    if (chunk.is_data_buffer_set() && chunk.data_memory_type() == MemoryType::HOST) {
+        map_value.ready_chunks.emplace_back(std::move(chunk));
+    } else {
+        map_value.ready_chunks.emplace_front(std::move(chunk));
     }
-    map_value.chunks.push_back(std::move(chunk));
 }
 
 template
 bool PostBox::is_empty(PartID pid) const {
     auto& map_value = pigeonhole_.at(key_map_fn_(pid));
     std::lock_guard lock(map_value.mutex);
-    return map_value.chunks.empty();
+    // return map_value.chunks.empty();
+    return map_value.is_empty_unsafe();
 }
 
 template
 std::vector PostBox::extract(PartID pid) {
+    RAPIDSMPF_NVTX_FUNC_RANGE();
     return extract_by_key(key_map_fn_(pid));
 }
 
@@ -55,14 +62,23 @@ template
 std::vector PostBox::extract_by_key(KeyType key) {
     auto& map_value = pigeonhole_.at(key);
     std::lock_guard lock(map_value.mutex);
-    RAPIDSMPF_EXPECTS(!map_value.chunks.empty(), "PostBox.extract(): partition is empty");
+    RAPIDSMPF_EXPECTS(
+        !map_value.is_empty_unsafe(), "PostBox.extract(): partition is empty"
+    );
 
-    std::vector ret = std::move(map_value.chunks);
-    map_value.chunks.clear();
+    std::vector ret(
+        std::make_move_iterator(map_value.ready_chunks.begin()),
+        std::make_move_iterator(map_value.ready_chunks.end())
+    );
+    map_value.ready_chunks.clear();
+    // RAPIDSMPF_EXPECTS(
+    //     n_non_empty_keys_.fetch_sub(1, std::memory_order_relaxed) > 0,
+    //     "PostBox.extract(): n_non_empty_keys_ is already 0"
+    //     );
     RAPIDSMPF_EXPECTS(
-        n_non_empty_keys_.fetch_sub(1, std::memory_order_relaxed) > 0,
-        "PostBox.extract(): n_non_empty_keys_ is already 0"
+        n_chunks.fetch_sub(ret.size(), std::memory_order_relaxed) >= ret.size(),
+        "PostBox.extract(): n_chunks is negative"
     );
     return ret;
 }
 
@@ -78,79 +94,194 @@ std::vector PostBox::extract_all_ready() {
             continue;
         }
 
-        // Partition: non-ready chunks first, ready chunks at the end
-        auto partition_point =
-            std::ranges::partition(map_value.chunks, [](const Chunk& c) {
-                return !c.is_ready();
-            }).begin();
+        // // Partition: non-ready chunks first, ready chunks at the end
+        // auto partition_point =
+        //     std::ranges::partition(map_value.chunks, [](const Chunk& c) {
+        //         return !c.is_ready();
+        //     }).begin();
 
-        // if the chunks are available and all are ready, then all chunks will be
-        // extracted
-        if (map_value.chunks.begin() == partition_point
-            && partition_point != map_value.chunks.end())
-        {
-            RAPIDSMPF_EXPECTS(
-                n_non_empty_keys_.fetch_sub(1, std::memory_order_relaxed) > 0,
-                "PostBox.extract_all_ready(): n_non_empty_keys_ is already 0"
-            );
-        }
+        // // if the chunks are available and all are ready, then all chunks will be
+        // // extracted
+        // if (map_value.chunks.begin() == partition_point
+        //     && partition_point != map_value.chunks.end())
+        // {
+        //     RAPIDSMPF_EXPECTS(
+        //         n_non_empty_keys_.fetch_sub(1, std::memory_order_relaxed) > 0,
+        //         "PostBox.extract_all_ready(): n_non_empty_keys_ is already 0"
+        //     );
+        // }
 
-        // Move ready chunks to result
-        ret.insert(
-            ret.end(),
-            std::make_move_iterator(partition_point),
-            std::make_move_iterator(map_value.chunks.end())
-        );
+        // // Move ready chunks to result
+        // ret.insert(
+        //     ret.end(),
+        //     std::make_move_iterator(partition_point),
+        //     std::make_move_iterator(map_value.chunks.end())
+        // );
+
+        // // Remove ready chunks from the vector
+        // map_value.chunks.erase(partition_point, map_value.chunks.end());
 
-        // Remove ready chunks from the vector
-        map_value.chunks.erase(partition_point, map_value.chunks.end());
+        for (auto it = map_value.ready_chunks.begin();
+             it != map_value.ready_chunks.end();)
+        {
+            if (it->is_ready()) {
+                ret.emplace_back(std::move(*it));
+                it = map_value.ready_chunks.erase(it);
+                RAPIDSMPF_EXPECTS(
+                    n_chunks.fetch_sub(1, std::memory_order_relaxed) >= 1,
+                    "PostBox.extract_all_ready(): n_chunks is negative"
+                );
+            } else {
+                ++it;
+            }
+        }
     }
     return ret;
 }
 
 template
 bool PostBox::empty() const {
-    return n_non_empty_keys_.load(std::memory_order_acquire) == 0;
+    return n_chunks.load(std::memory_order_acquire) == 0;
 }
 
 template
 size_t PostBox::spill(
-    BufferResource* br, Communicator::Logger& log, size_t amount
+    BufferResource* br, Communicator::Logger& /* log */, size_t amount
 ) {
-    RAPIDSMPF_NVTX_FUNC_RANGE();
+    RAPIDSMPF_NVTX_SCOPED_RANGE("spill-inside-postbox");
 
     // individually lock each key and spill the chunks in it. If we are unable to lock the
     // key, then it will be skipped.
     size_t total_spilled = 0;
+
+    // auto it_start = pigeonhole_.begin();
+    // std::advance(it_start, dist_(rng_));
+
+
+    // auto spill_chunks = [&](auto& map_value) {
+    //     std::unique_lock lock(map_value.mutex, std::try_to_lock);
+
+    //     if (!lock) {
+    //         return false;
+    //     }
+
+    //     for (auto& chunk : map_value.chunks) {
+    //         if (chunk.is_data_buffer_set()
+    //             && chunk.data_memory_type() == MemoryType::DEVICE)
+    //         {
+    //             size_t size = chunk.concat_data_size();
+    //             auto [host_reservation, host_overbooking] =
+    //                 br->reserve(MemoryType::HOST, size, true);
+    //             if (host_overbooking > 0) {
+    //                 log.warn(
+    //                     "Cannot spill to host because of host memory overbooking: ",
+    //                     format_nbytes(host_overbooking)
+    //                 );
+    //                 continue;
+    //             }
+    //             chunk.set_data_buffer(
+    //                 br->move(chunk.release_data_buffer(), host_reservation)
+    //             );
+    //             total_spilled += size;
+    //             if (total_spilled >= amount) {
+    //                 return true;
+    //             }
+    //         }
+    //     }
+    //     return false;
+    // };
+
+    // for (auto it = it_start; it != pigeonhole_.end(); ++it) {
+    //     auto& [key, map_value] = *it;
+    //     // std::unique_lock lock(map_value.mutex, std::try_to_lock);
+    //     // if (!lock) { // skip to the next key
+    //     //     continue;
+    //     // }
+
+    //     // for (auto& chunk : map_value.chunks) {
+    //     //     if (chunk.is_data_buffer_set()
+    //     //         && chunk.data_memory_type() == MemoryType::DEVICE)
+    //     //     {
+    //     //         size_t size = chunk.concat_data_size();
+    //     //         auto [host_reservation, host_overbooking] =
+    //     //             br->reserve(MemoryType::HOST, size, true);
+    //     //         if (host_overbooking > 0) {
+    //     //             log.warn(
+    //     //                 "Cannot spill to host because of host memory overbooking: ",
+    //     //                 format_nbytes(host_overbooking)
+    //     //             );
+    //     //             continue;
+    //     //         }
+    //     //         chunk.set_data_buffer(
+    //     //             br->move(chunk.release_data_buffer(), host_reservation)
+    //     //         );
+    //     //         total_spilled += size;
+    //     //         if (total_spilled >= amount) {
+    //     //             break;
+    //     //         }
+    //     //     }
+    //     // }
+    //     if (spill_chunks(map_value)) {
+    //         break;
+    //     }
+    // }
+
+    // for (auto it = pigeonhole_.begin(); it != it_start; ++it) {
+    //     auto& [key, map_value] = *it;
+    //     if (spill_chunks(map_value)) {
+    //         break;
+    //     }
+    // }
+
     for (auto& [key, map_value] : pigeonhole_) {
         std::unique_lock lock(map_value.mutex, std::try_to_lock);
         if (!lock) {  // skip to the next key
             continue;
         }
 
-        for (auto& chunk : map_value.chunks) {
+        std::vector spillable_chunks;
+        for (auto it = map_value.ready_chunks.begin();
+             it != map_value.ready_chunks.end();)
+        {
+            auto& chunk = *it;
             if (chunk.is_data_buffer_set()
                 && chunk.data_memory_type() == MemoryType::DEVICE)
             {
                 size_t size = chunk.concat_data_size();
-                auto [host_reservation, host_overbooking] =
-                    br->reserve(MemoryType::HOST, size, true);
-                if (host_overbooking > 0) {
-                    log.warn(
-                        "Cannot spill to host because of host memory overbooking: ",
-                        format_nbytes(host_overbooking)
-                    );
-                    continue;
-                }
-                chunk.set_data_buffer(
-                    br->move(chunk.release_data_buffer(), host_reservation)
-                );
+                spillable_chunks.emplace_back(std::move(chunk));
+                it = map_value.ready_chunks.erase(it);
                 total_spilled += size;
                 if (total_spilled >= amount) {
                     break;
                 }
+            } else {
+                ++it;
             }
         }
+        map_value.n_spilling_chunks += spillable_chunks.size();
+        // release lock
+        lock.unlock();
+
+        // spill the chunks to host memory
+        while (!spillable_chunks.empty()) {
+            auto chunk = std::move(spillable_chunks.back());
+            spillable_chunks.pop_back();
+            size_t size = chunk.concat_data_size();
+            auto [host_reservation, host_overbooking] =
+                br->reserve(MemoryType::HOST, size, true);
+            RAPIDSMPF_EXPECTS(
+                host_overbooking == 0,
+                "Cannot spill to host because of host memory overbooking: "
+                    + std::to_string(host_overbooking)
+            );
+            chunk.set_data_buffer(br->move(chunk.release_data_buffer(), host_reservation)
+            );
+
+            lock.lock();
+            map_value.ready_chunks.emplace_back(std::move(chunk));
+            map_value.n_spilling_chunks--;
+            lock.unlock();
+        }
     }
 
     return total_spilled;
@@ -164,8 +295,8 @@ std::string PostBox::str() const {
     std::stringstream ss;
     ss << "PostBox(";
     for (auto const& [key, map_value] : pigeonhole_) {
-        ss << "k=" << key << ":[";
-        for (auto const& chunk : map_value.chunks) {
+        ss << "k=" << key << " nspill=" << map_value.n_spilling_chunks << ":[";
+        for (auto const& chunk : map_value.ready_chunks) {
             // assert(cid == chunk.chunk_id());
             if (chunk.is_control_message(0)) {
                 ss << "EOP" << chunk.expected_num_chunks(0) << ", ";
@@ -173,7 +304,7 @@ std::string PostBox::str() const {
                 ss << chunk.chunk_id() << ", ";
             }
         }
-        ss << (map_value.chunks.empty() ? "], " : "\b\b], ");
+        ss << (map_value.ready_chunks.empty() ? "], " : "\b\b], ");
     }
     ss << "\b\b)";
     return ss.str();
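The `PostBox::spill` rewrite above is a two-phase operation: under the per-key lock it detaches device-resident chunks from `ready_chunks` and adds them to `n_spilling_chunks`; the lock is then dropped for the expensive device-to-host moves, and each chunk is re-attached under a brief re-lock. Because spilling chunks are temporarily outside the list, emptiness checks must count them as still present, which is what `MapValue::is_empty_unsafe()` does. Below is a standalone sketch of the same pattern using only the standard library; `Slot`, `Item`, and the simulated copy are hypothetical stand-ins, not project code.

```cpp
#include <cstddef>
#include <list>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

struct Item {
    std::string payload;
    bool on_device{true};
};

// Toy per-key slot: items that are mid-spill are excluded from the list but
// still counted, so the slot never looks empty while a spill is in flight.
struct Slot {
    std::mutex mutex;
    std::list<Item> ready;
    std::size_t n_spilling{0};

    bool empty_unsafe() const {
        return ready.empty() && n_spilling == 0;
    }
};

// Phase 1 (locked): detach device-resident items. Phase 2 (unlocked): do the
// expensive "copy", then re-attach each item under a short re-lock.
std::size_t spill_slot(Slot& slot, std::size_t amount) {
    std::unique_lock lock(slot.mutex, std::try_to_lock);
    if (!lock) {
        return 0;  // contended: skip this key and let the caller try another
    }

    std::size_t spilled = 0;
    std::vector<Item> detached;
    for (auto it = slot.ready.begin(); it != slot.ready.end();) {
        if (it->on_device && spilled < amount) {
            spilled += it->payload.size();
            detached.push_back(std::move(*it));
            it = slot.ready.erase(it);
        } else {
            ++it;
        }
    }
    slot.n_spilling += detached.size();
    lock.unlock();

    for (auto& item : detached) {
        item.on_device = false;  // stands in for the device-to-host move

        lock.lock();
        slot.ready.push_back(std::move(item));
        --slot.n_spilling;
        lock.unlock();
    }
    return spilled;
}
```

A concurrent reader that takes the slot lock during phase two sees `n_spilling > 0`, so the key is not reported empty and the overall chunk count stays consistent while the copies are in flight.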