diff --git a/cpp/include/rapidsmpf/shuffler/postbox.hpp b/cpp/include/rapidsmpf/shuffler/postbox.hpp
index 48876d949..859abc673 100644
--- a/cpp/include/rapidsmpf/shuffler/postbox.hpp
+++ b/cpp/include/rapidsmpf/shuffler/postbox.hpp
@@ -4,13 +4,18 @@
  */
 #pragma once
 
+#include <atomic>
 #include <functional>
+#include <list>
 #include <mutex>
+#include <random>
+#include <ranges>
 #include <string>
 #include <unordered_map>
 #include <utility>
 #include <vector>
 
+#include <rapidsmpf/buffer/resource.hpp>
 #include <rapidsmpf/shuffler/chunk.hpp>
 
 namespace rapidsmpf::shuffler::detail {
@@ -31,17 +36,24 @@ class PostBox {
      *
      * @tparam Fn The type of the function that maps a partition ID to a key.
      * @param key_map_fn A function that maps a partition ID to a key.
-     * @param num_keys_hint The number of keys to reserve space for.
+     * @param keys The keys expected to be used in the PostBox.
      *
      * @note The `key_map_fn` must be convertible to a function that takes a `PartID`
      * and returns a `KeyType`.
      */
-    template <typename Fn>
-    PostBox(Fn&& key_map_fn, size_t num_keys_hint = 0)
-        : key_map_fn_(std::move(key_map_fn)) {
-        if (num_keys_hint > 0) {
-            pigeonhole_.reserve(num_keys_hint);
+    template <typename Fn, typename Range>
+        requires std::convertible_to<std::ranges::range_value_t<Range>, KeyType>
+    PostBox(Fn&& key_map_fn, Range&& keys) : key_map_fn_(std::move(key_map_fn)) {
+        pigeonhole_.reserve(std::ranges::size(keys));
+        for (auto const& key : keys) {
+            pigeonhole_.emplace(
+                std::piecewise_construct,
+                std::forward_as_tuple(key),
+                std::forward_as_tuple()
+            );
         }
+        rng_ = std::mt19937(std::random_device{}());
+        dist_ = std::uniform_int_distribution<size_t>(0, keys.size() - 1);
     }
 
     /**
@@ -60,18 +72,7 @@ class PostBox {
      * @note The result reflects a snapshot at the time of the call and may change
      * immediately afterward.
      */
-    bool is_empty(PartID pid) const;
-
-    /**
-     * @brief Extracts a specific chunk from the PostBox.
-     *
-     * @param pid The ID of the partition containing the chunk.
-     * @param cid The ID of the chunk to be accessed.
-     * @return The extracted chunk.
-     *
-     * @throws std::out_of_range If the chunk is not found.
-     */
-    [[nodiscard]] Chunk extract(PartID pid, ChunkID cid);
+    [[nodiscard]] bool is_empty(PartID pid) const;
 
     /**
      * @brief Extracts all chunks associated with a specific partition.
@@ -81,7 +82,7 @@ class PostBox {
      *
      * @throws std::out_of_range If the partition is not found.
      */
-    std::unordered_map<ChunkID, Chunk> extract(PartID pid);
+    std::vector<Chunk> extract(PartID pid);
 
     /**
      * @brief Extracts all chunks associated with a specific key.
@@ -91,7 +92,7 @@ class PostBox {
      *
      * @throws std::out_of_range If the key is not found.
      */
-    std::unordered_map<ChunkID, Chunk> extract_by_key(KeyType key);
+    std::vector<Chunk> extract_by_key(KeyType key);
 
     /**
      * @brief Extracts all ready chunks from the PostBox.
@@ -107,30 +108,45 @@ class PostBox {
      */
     [[nodiscard]] bool empty() const;
 
-    /**
-     * @brief Searches for chunks of the specified memory type.
-     *
-     * @param mem_type The type of memory to search within.
-     * @return A vector of tuples, where each tuple contains: PartID, ChunkID, and the
-     * size of the chunk.
-     */
-    [[nodiscard]] std::vector<std::tuple<PartID, ChunkID, std::size_t>> search(
-        MemoryType mem_type
-    ) const;
-
     /**
      * @brief Returns a description of this instance.
      * @return The description.
      */
    [[nodiscard]] std::string str() const;
 
+    /**
+     * @brief Spills the specified amount of data from the PostBox.
+     *
+     * @param br Buffer resource to use for spilling.
+     * @param log Logger to use for logging (currently unused).
+     * @param amount The number of bytes to spill.
+     * @return The number of bytes actually spilled.
+     */
+    size_t spill(BufferResource* br, Communicator::Logger& log, size_t amount);
+
   private:
-    // TODO: more fine-grained locking e.g. by locking each partition individually.
-    mutable std::mutex mutex_;
+    /**
+     * @brief Map value for the PostBox.
+     */
+    struct MapValue {
+        mutable std::mutex mutex;  ///< Mutex protecting this key's state.
+        std::list<Chunk> ready_chunks;  ///< Chunks currently held for this key.
+        size_t n_spilling_chunks{0};  ///< Number of chunks that are being spilled.
+
+        [[nodiscard]] bool is_empty_unsafe() const noexcept {
+            return ready_chunks.empty() && n_spilling_chunks == 0;
+        }
+    };
+
     std::function<KeyType(PartID)> key_map_fn_;  ///< Function to map partition IDs to keys.
-    std::unordered_map<KeyType, std::unordered_map<ChunkID, Chunk>>
-        pigeonhole_;  ///< Storage for chunks, organized by a key and chunk ID.
+    std::unordered_map<KeyType, MapValue> pigeonhole_;  ///< Storage for chunks.
+    std::atomic<size_t> n_chunks{0};  ///< Number of chunks in the PostBox. Keys are
+                                      ///< never erased from the pigeonhole map, so
+                                      ///< this count is used to check emptiness.
+    std::mt19937 rng_;  ///< Random number generator.
+    std::uniform_int_distribution<size_t>
+        dist_;  ///< Distribution for selecting a random key.
 };
 
 /**
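The constructor change above is the heart of the redesign: every key gets its own `MapValue` (and mutex) at construction time, so the pigeonhole map itself never mutates afterwards and can be traversed without a map-wide lock. A minimal construction sketch, assuming a hypothetical `PostBox<int>` keyed by rank and C++20 `<ranges>`:

    #include <ranges>
    using rapidsmpf::shuffler::detail::PostBox;
    using rapidsmpf::shuffler::PartID;

    int const nranks = 4;
    PostBox<int> postbox{
        // key map: partition -> owning rank (illustrative round-robin)
        [nranks](PartID pid) -> int { return static_cast<int>(pid) % nranks; },
        // the full key set must be known up front; keys are never added later
        std::views::iota(0, nranks)
    };

A consequence worth noting: inserting or extracting a key outside the construction range now throws from `pigeonhole_.at(key)` instead of silently creating a new slot.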
diff --git a/cpp/include/rapidsmpf/shuffler/shuffler.hpp b/cpp/include/rapidsmpf/shuffler/shuffler.hpp
index 4e3b66653..e3a063a42 100644
--- a/cpp/include/rapidsmpf/shuffler/shuffler.hpp
+++ b/cpp/include/rapidsmpf/shuffler/shuffler.hpp
@@ -341,29 +341,26 @@ class Shuffler {
   private:
     BufferResource* br_;
+    std::shared_ptr<Communicator> comm_;
     std::atomic<bool> active_{true};
+    std::vector<PartID> const local_partitions_;
+
     detail::PostBox<Rank> outgoing_postbox_;  ///< Postbox for outgoing chunks that are
                                               ///< ready to be sent to other ranks.
     detail::PostBox<PartID> ready_postbox_;  ///< Postbox for received chunks that are
                                              ///< ready to be extracted by the user.
-    std::shared_ptr<Communicator> comm_;
     std::shared_ptr<ProgressThread> progress_thread_;
     ProgressThread::FunctionID progress_thread_function_id_;
     OpID const op_id_;
     SpillManager::SpillFunctionID spill_function_id_;
-    std::vector<PartID> const local_partitions_;
 
     detail::FinishCounter finish_counter_;
     std::unordered_map<PartID, std::size_t> outbound_chunk_counter_;
     mutable std::mutex outbound_chunk_counter_mutex_;
 
-    // We protect ready_postbox extraction to avoid returning a chunk that is in the
-    // process of being spilled by `Shuffler::spill`.
-    mutable std::mutex ready_postbox_spilling_mutex_;
-
     std::atomic<ChunkID> chunk_id_counter_{0};
 
     std::shared_ptr<Statistics> statistics_;
diff --git a/cpp/src/memory/spill_manager.cpp b/cpp/src/memory/spill_manager.cpp
index 17bfe5e97..a08239209 100644
--- a/cpp/src/memory/spill_manager.cpp
+++ b/cpp/src/memory/spill_manager.cpp
@@ -69,14 +69,22 @@ std::size_t SpillManager::spill(std::size_t amount) {
     std::size_t spilled{0};
     std::unique_lock lock(mutex_);
     auto const t0_elapsed = Clock::now();
-    for (auto const [_, fid] : spill_function_priorities_) {
+    // Copy the registered spill functions and release the lock, so that the spill
+    // functions run without holding the SpillManager mutex (e.g. allowing them to
+    // call back into the SpillManager without deadlocking).
+    auto spill_functions_cp = spill_functions_;
+    lock.unlock();
+
+    for (auto& [id, fn] : spill_functions_cp) {
         if (spilled >= amount) {
             break;
         }
-        spilled += spill_functions_.at(fid)(amount - spilled);
+        spilled += fn(amount - spilled);
     }
     auto const t1_elapsed = Clock::now();
-    lock.unlock();
     auto& stats = *br_->statistics();
     stats.add_duration_stat("spill-time-device-to-host", t1_elapsed - t0_elapsed);
     stats.add_bytes_stat("spill-bytes-device-to-host", spilled);
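`SpillManager::spill()` now snapshots the registered spill functions and releases its mutex before invoking them, so a spill function may itself take locks (such as the per-key PostBox mutexes) or re-enter the SpillManager without deadlocking; the trade-off is that an in-flight spill does not observe functions registered after the snapshot, and the copied map is iterated by function ID rather than via `spill_function_priorities_`. A standalone sketch of the snapshot-then-unlock pattern (illustrative names, not rapidsmpf code):

    #include <cstddef>
    #include <functional>
    #include <map>
    #include <mutex>

    struct Manager {
        std::mutex mutex_;
        std::map<int, std::function<std::size_t(std::size_t)>> fns_;

        std::size_t spill(std::size_t amount) {
            std::unique_lock lock(mutex_);
            auto snapshot = fns_;  // copy the callbacks while protected
            lock.unlock();         // callbacks run without the manager lock held
            std::size_t spilled = 0;
            for (auto& [id, fn] : snapshot) {
                if (spilled >= amount) break;
                spilled += fn(amount - spilled);
            }
            return spilled;
        }
    };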
diff --git a/cpp/src/shuffler/postbox.cpp b/cpp/src/shuffler/postbox.cpp
index 4450ba517..ae8bbde0d 100644
--- a/cpp/src/shuffler/postbox.cpp
+++ b/cpp/src/shuffler/postbox.cpp
@@ -3,9 +3,12 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
+#include <cstddef>
+#include <iterator>
 #include <mutex>
 #include <sstream>
 
+#include <rapidsmpf/buffer/resource.hpp>
 #include <rapidsmpf/error.hpp>
 #include <rapidsmpf/nvtx.hpp>
 #include <rapidsmpf/shuffler/postbox.hpp>
@@ -22,88 +25,266 @@ void PostBox<KeyType>::insert(Chunk&& chunk) {
             "PostBox.insert(): all messages in the chunk must map to the same key"
         );
     }
-    std::lock_guard const lock(mutex_);
-    RAPIDSMPF_EXPECTS(
-        pigeonhole_[key].emplace(chunk.chunk_id(), std::move(chunk)).second,
-        "PostBox.insert(): chunk already exist"
-    );
+
+    auto& map_value = pigeonhole_.at(key);
+    std::lock_guard lock(map_value.mutex);
+    n_chunks.fetch_add(1, std::memory_order_relaxed);
+    // Host-resident (e.g. already spilled) chunks go to the back of the list so that
+    // device-resident chunks are extracted first.
+    if (chunk.is_data_buffer_set() && chunk.data_memory_type() == MemoryType::HOST) {
+        map_value.ready_chunks.emplace_back(std::move(chunk));
+    } else {
+        map_value.ready_chunks.emplace_front(std::move(chunk));
+    }
 }
 
 template <typename KeyType>
 bool PostBox<KeyType>::is_empty(PartID pid) const {
-    std::lock_guard const lock(mutex_);
-    return !pigeonhole_.contains(key_map_fn_(pid));
+    auto& map_value = pigeonhole_.at(key_map_fn_(pid));
+    std::lock_guard lock(map_value.mutex);
+    return map_value.is_empty_unsafe();
 }
 
 template <typename KeyType>
-Chunk PostBox<KeyType>::extract(PartID pid, ChunkID cid) {
-    std::lock_guard const lock(mutex_);
-    return extract_item(pigeonhole_[key_map_fn_(pid)], cid).second;
+std::vector<Chunk> PostBox<KeyType>::extract(PartID pid) {
+    RAPIDSMPF_NVTX_FUNC_RANGE();
+    return extract_by_key(key_map_fn_(pid));
 }
 
 template <typename KeyType>
-std::unordered_map<ChunkID, Chunk> PostBox<KeyType>::extract(PartID pid) {
-    std::lock_guard const lock(mutex_);
-    return extract_value(pigeonhole_, key_map_fn_(pid));
-}
-
-template <typename KeyType>
-std::unordered_map<ChunkID, Chunk> PostBox<KeyType>::extract_by_key(KeyType key) {
-    std::lock_guard const lock(mutex_);
-    return extract_value(pigeonhole_, key);
+std::vector<Chunk> PostBox<KeyType>::extract_by_key(KeyType key) {
+    auto& map_value = pigeonhole_.at(key);
+    std::lock_guard lock(map_value.mutex);
+    RAPIDSMPF_EXPECTS(
+        !map_value.is_empty_unsafe(), "PostBox.extract(): partition is empty"
+    );
+
+    std::vector<Chunk> ret(
+        std::make_move_iterator(map_value.ready_chunks.begin()),
+        std::make_move_iterator(map_value.ready_chunks.end())
+    );
+    map_value.ready_chunks.clear();
+
+    RAPIDSMPF_EXPECTS(
+        n_chunks.fetch_sub(ret.size(), std::memory_order_relaxed) >= ret.size(),
+        "PostBox.extract(): n_chunks is negative"
+    );
+    return ret;
 }
 
 template <typename KeyType>
 std::vector<Chunk> PostBox<KeyType>::extract_all_ready() {
-    std::lock_guard const lock(mutex_);
     std::vector<Chunk> ret;
     // Iterate through the outer map
-    auto pid_it = pigeonhole_.begin();
-    while (pid_it != pigeonhole_.end()) {
-        // Iterate through the inner map
-        auto& chunks = pid_it->second;
-        auto chunk_it = chunks.begin();
-        while (chunk_it != chunks.end()) {
-            if (chunk_it->second.is_ready()) {
-                ret.emplace_back(std::move(chunk_it->second));
-                chunk_it = chunks.erase(chunk_it);
-            } else {
-                ++chunk_it;
-            }
+    for (auto& [key, map_value] : pigeonhole_) {
+        // Skip keys that are currently locked by another thread instead of blocking.
+        std::unique_lock lock(map_value.mutex, std::try_to_lock);
+        if (!lock.owns_lock()) {
+            continue;
         }
-        // Remove the pid entry if its chunks map is empty
-        if (chunks.empty()) {
-            pid_it = pigeonhole_.erase(pid_it);
-        } else {
-            ++pid_it;
+        for (auto it = map_value.ready_chunks.begin();
+             it != map_value.ready_chunks.end();)
+        {
+            if (it->is_ready()) {
+                ret.emplace_back(std::move(*it));
+                it = map_value.ready_chunks.erase(it);
+                RAPIDSMPF_EXPECTS(
+                    n_chunks.fetch_sub(1, std::memory_order_relaxed) >= 1,
+                    "PostBox.extract_all_ready(): n_chunks is negative"
+                );
+            } else {
+                ++it;
+            }
         }
     }
     return ret;
 }
 
 template <typename KeyType>
 bool PostBox<KeyType>::empty() const {
-    std::lock_guard const lock(mutex_);
-    return pigeonhole_.empty();
+    return n_chunks.load(std::memory_order_acquire) == 0;
 }
 
 template <typename KeyType>
-std::vector<std::tuple<PartID, ChunkID, std::size_t>> PostBox<KeyType>::search(
-    MemoryType mem_type
-) const {
-    std::lock_guard const lock(mutex_);
-    std::vector<std::tuple<PartID, ChunkID, std::size_t>> ret;
-    for (auto& [key, chunks] : pigeonhole_) {
-        for (auto& [cid, chunk] : chunks) {
-            if (!chunk.is_control_message(0) && chunk.data_memory_type() == mem_type) {
-                ret.emplace_back(key, cid, chunk.concat_data_size());
-            }
-        }
-    }
-    return ret;
-}
+size_t PostBox<KeyType>::spill(
+    BufferResource* br, Communicator::Logger& /* log */, size_t amount
+) {
+    RAPIDSMPF_NVTX_SCOPED_RANGE("spill-inside-postbox");
+
+    // Individually lock each key and spill its chunks. A key that cannot be locked
+    // immediately is skipped.
+    // TODO: start the scan at a random key (see rng_/dist_) so that the same keys
+    // are not always spilled first.
+    size_t total_spilled = 0;
+
+    for (auto& [key, map_value] : pigeonhole_) {
+        std::unique_lock lock(map_value.mutex, std::try_to_lock);
+        if (!lock) {  // skip to the next key
+            continue;
+        }
+
+        // Phase 1 (locked): detach device-resident chunks from the key.
+        std::vector<Chunk> spillable_chunks;
+        for (auto it = map_value.ready_chunks.begin();
+             it != map_value.ready_chunks.end();)
+        {
+            auto& chunk = *it;
+            if (chunk.is_data_buffer_set()
+                && chunk.data_memory_type() == MemoryType::DEVICE)
+            {
+                size_t size = chunk.concat_data_size();
+                spillable_chunks.emplace_back(std::move(chunk));
+                it = map_value.ready_chunks.erase(it);
+                total_spilled += size;
+                if (total_spilled >= amount) {
+                    break;
+                }
+            } else {
+                ++it;
+            }
+        }
+        map_value.n_spilling_chunks += spillable_chunks.size();
+        // Release the lock while the buffers are moved to host memory.
+        lock.unlock();
+
+        // Phase 2 (unlocked): spill the detached chunks and reinsert them.
+        while (!spillable_chunks.empty()) {
+            auto chunk = std::move(spillable_chunks.back());
+            spillable_chunks.pop_back();
+            size_t size = chunk.concat_data_size();
+            auto [host_reservation, host_overbooking] =
+                br->reserve(MemoryType::HOST, size, true);
+            RAPIDSMPF_EXPECTS(
+                host_overbooking == 0,
+                "Cannot spill to host because of host memory overbooking: "
+                    + std::to_string(host_overbooking)
+            );
+            chunk.set_data_buffer(
+                br->move(chunk.release_data_buffer(), host_reservation)
+            );
+
+            lock.lock();
+            map_value.ready_chunks.emplace_back(std::move(chunk));
+            map_value.n_spilling_chunks--;
+            lock.unlock();
+        }
+    }
+
+    return total_spilled;
 }
 
 template <typename KeyType>
@@ -113,17 +294,17 @@ std::string PostBox<KeyType>::str() const {
     }
     std::stringstream ss;
     ss << "PostBox(";
-    for (auto const& [key, chunks] : pigeonhole_) {
-        ss << "k=" << key << ": [";
-        for (auto const& [cid, chunk] : chunks) {
-            assert(cid == chunk.chunk_id());
+    for (auto const& [key, map_value] : pigeonhole_) {
+        ss << "k=" << key << " nspill=" << map_value.n_spilling_chunks << ":[";
+        for (auto const& chunk : map_value.ready_chunks) {
             if (chunk.is_control_message(0)) {
                 ss << "EOP" << chunk.expected_num_chunks(0) << ", ";
             } else {
-                ss << cid << ", ";
+                ss << chunk.chunk_id() << ", ";
             }
         }
-        ss << "\b\b], ";
+        ss << (map_value.ready_chunks.empty() ? "], " : "\b\b], ");
     }
     ss << "\b\b)";
     return ss.str();
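`PostBox<KeyType>::spill()` performs the device-to-host copies without holding the per-key lock; `n_spilling_chunks` is what keeps `is_empty()` truthful while chunks are detached, which is exactly the hazard the old `ready_postbox_spilling_mutex_` existed to prevent. A self-contained sketch of the detach/spill/reinsert protocol (illustrative types standing in for `Chunk` and the buffer APIs):

    #include <cstddef>
    #include <list>
    #include <mutex>

    struct Slot {
        std::mutex mutex;
        std::list<int> ready;      // stands in for std::list<Chunk>
        std::size_t spilling{0};   // chunks detached for spilling
        bool is_empty_unsafe() const { return ready.empty() && spilling == 0; }
    };

    void spill_one(Slot& s) {
        std::unique_lock lock(s.mutex);
        if (s.ready.empty()) return;
        int chunk = s.ready.front();  // phase 1: detach under the lock
        s.ready.pop_front();
        ++s.spilling;                 // slot still reports non-empty
        lock.unlock();
        // phase 2: move the chunk's buffer device -> host here, unlocked
        lock.lock();
        s.ready.push_back(chunk);     // phase 3: reinsert at the cold end
        --s.spilling;
    }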
- * - * @warning This may temporarily empty the postbox, causing emptiness checks to return - * true even though the postbox is not actually empty. As a result, in the current - * implementation `postbox_spilling()` must not be used to spill `outgoing_postbox_`. - */ -template -std::size_t postbox_spilling( - BufferResource* br, - Communicator::Logger& log, - PostBox& postbox, - std::size_t amount -) { - RAPIDSMPF_NVTX_FUNC_RANGE(); - // Let's look for chunks to spill in the outbox. - auto const chunk_info = postbox.search(MemoryType::DEVICE); - std::size_t total_spilled{0}; - for (auto [pid, cid, size] : chunk_info) { - if (size == 0) { // skip empty data buffers - continue; - } - - // TODO: Use a clever strategy to decide which chunks to spill. For now, we - // just spill the chunks in an arbitrary order. - auto [host_reservation, host_overbooking] = - br->reserve(MemoryType::HOST, size, true); - if (host_overbooking > 0) { - log.warn( - "Cannot spill to host because of host memory overbooking: ", - format_nbytes(host_overbooking) - ); - continue; - } - // We extract the chunk, spilled it, and insert it back into the PostBox. - auto chunk = postbox.extract(pid, cid); - chunk.set_data_buffer(br->move(chunk.release_data_buffer(), host_reservation)); - postbox.insert(std::move(chunk)); - if ((total_spilled += size) >= amount) { - break; - } - } - return total_spilled; -} - } // namespace class Shuffler::Progress { @@ -476,20 +413,17 @@ Shuffler::Shuffler( : total_num_partitions{total_num_partitions}, partition_owner{std::move(partition_owner_fn)}, br_{br}, + comm_{std::move(comm)}, + local_partitions_{local_partitions(comm_, total_num_partitions, partition_owner)}, outgoing_postbox_{ [this](PartID pid) -> Rank { return this->partition_owner(this->comm_, pid); }, // extract Rank from pid - static_cast(comm->nranks()) + std::views::iota(Rank(0), Rank(this->comm_->nranks())) }, - ready_postbox_{ - [](PartID pid) -> PartID { return pid; }, // identity mapping - static_cast(total_num_partitions), - }, - comm_{std::move(comm)}, + ready_postbox_{std::identity{}, local_partitions_}, progress_thread_{std::move(progress_thread)}, op_id_{op_id}, - local_partitions_{local_partitions(comm_, total_num_partitions, partition_owner)}, finish_counter_{comm_->nranks(), local_partitions_, std::move(finished_callback)}, statistics_{std::move(statistics)} { RAPIDSMPF_EXPECTS(comm_ != nullptr, "the communicator pointer cannot be NULL"); @@ -768,7 +702,6 @@ void Shuffler::insert_finished(std::vector&& pids) { std::vector Shuffler::extract(PartID pid) { RAPIDSMPF_NVTX_FUNC_RANGE(); - std::unique_lock lock(ready_postbox_spilling_mutex_); // Quick return if the partition is empty. 
diff --git a/cpp/tests/test_shuffler.cpp b/cpp/tests/test_shuffler.cpp
index 4a511f1a4..3de3f4a51 100644
--- a/cpp/tests/test_shuffler.cpp
+++ b/cpp/tests/test_shuffler.cpp
@@ -570,7 +570,7 @@ class ShuffleInsertGroupedTest
         }
 
         auto chunks = shuffler.outgoing_postbox_.extract_by_key(rank);
-        for (auto& [cid, chunk] : chunks) {
+        for (auto& chunk : chunks) {
             for (size_t i = 0; i < chunk.n_messages(); ++i) {
                 outbound_chunks[chunk.part_id(i)]++;
                 if (chunk.is_control_message(i)) {
@@ -594,7 +594,7 @@ class ShuffleInsertGroupedTest
             // control messages
             outbound_chunks[pid]++;
             n_control_messages++;
-            for (auto& [cid, chunk] : local_chunks) {
+            for (auto& chunk : local_chunks) {
                 for (size_t i = 0; i < chunk.n_messages(); ++i) {
                     outbound_chunks[chunk.part_id(i)]++;
 
@@ -1012,7 +1012,7 @@ Chunk make_dummy_chunk(ChunkID chunk_id, PartID part_id) {
 }
 }  // namespace rapidsmpf::shuffler::detail
 
-class PostBoxTest : public cudf::test::BaseFixture {
+class PostBoxTest : public ::testing::Test {
   protected:
     using PostboxType = rapidsmpf::shuffler::detail::PostBox<rapidsmpf::Rank>;
 
@@ -1023,7 +1023,7 @@ class PostBoxTest : public cudf::test::BaseFixture {
             [this](rapidsmpf::shuffler::PartID part_id) {
                 return partition_owner(part_id);
             },
-            GlobalEnvironment->comm_->nranks()
+            std::views::iota(0, GlobalEnvironment->comm_->nranks())
         );
     }
 
@@ -1068,9 +1068,11 @@ TEST_F(PostBoxTest, InsertAndExtractMultipleChunks) {
 
         auto chunks = postbox->extract_by_key(rank);
         extracted_nchunks += chunks.size();
-        for (auto& [_, chunk] : chunks) {
-            extracted_chunks.emplace_back(std::move(chunk));
-        }
+        extracted_chunks.insert(
+            extracted_chunks.end(),
+            std::make_move_iterator(chunks.begin()),
+            std::make_move_iterator(chunks.end())
+        );
     }
     EXPECT_EQ(extracted_nchunks, num_chunks);
     EXPECT_TRUE(postbox->empty());
@@ -1292,3 +1294,171 @@ TEST(ShufflerTest, multiple_shutdowns) {
     shuffler.reset();
     GlobalEnvironment->barrier();
 }
+
+// Parameterized test for PostBox with concurrent insert/extract threads
+class PostBoxMultithreadedTest
+    : public ::testing::TestWithParam<std::tuple<uint32_t, uint32_t, uint32_t>> {
+  protected:
+    static constexpr size_t chunk_size = 1024;  // 1KB
+
+    void SetUp() override {
+        std::tie(num_threads, num_keys, chunks_per_key) = GetParam();
+        stream = cudf::get_default_stream();
+
+        // Create buffer resource with a chunk_size * num_keys / 10 device memory
+        // limit, so that concurrent inserts are forced to spill.
+        int64_t device_memory_limit = (num_keys * chunk_size) / 10;
+
+        auto mr_ptr = cudf::get_current_device_resource_ref();
+        mr = std::make_unique<rapidsmpf::RmmResourceAdaptor>(mr_ptr);
+
+        br = std::make_unique<rapidsmpf::BufferResource>(
+            mr_ptr,
+            std::unordered_map<
+                rapidsmpf::MemoryType,
+                rapidsmpf::BufferResource::MemoryAvailable>{
+                {rapidsmpf::MemoryType::DEVICE,
+                 rapidsmpf::LimitAvailableMemory(mr.get(), device_memory_limit)}
+            },
+            std::nullopt  // disable periodic spill check
+        );
+
+        // Create PostBox with identity mapping: keys are [0, num_keys)
+        postbox = std::make_unique<
+            rapidsmpf::shuffler::detail::PostBox<uint32_t>>(
+            std::identity{}, std::views::iota(0u, num_keys)
+        );
+
+        spill_function_id = br->spill_manager().add_spill_function(
+            [this](size_t amount) -> size_t {
+                return postbox->spill(
+                    br.get(), GlobalEnvironment->comm_->logger(), amount
+                );
+            },
+            0
+        );
+    }
+
+    void TearDown() override {
+        postbox.reset();
+        br->spill_manager().remove_spill_function(spill_function_id);
+        br.reset();
+        mr.reset();
+    }
+
+    uint32_t num_threads;
+    uint32_t num_keys;
+    uint32_t chunks_per_key;
+    rmm::cuda_stream_view stream;
+
+    std::unique_ptr<rapidsmpf::RmmResourceAdaptor> mr;
+    std::unique_ptr<rapidsmpf::BufferResource> br;
+    std::unique_ptr<rapidsmpf::shuffler::detail::PostBox<uint32_t>>
+        postbox;
+    rapidsmpf::SpillManager::SpillFunctionID spill_function_id;
+};
+
+TEST_P(PostBoxMultithreadedTest, ConcurrentInsertExtract) {
+    std::atomic<uint32_t> chunk_id_counter{0};
+    std::atomic<uint32_t> completed_insert_threads{0};
+    std::atomic<uint32_t> num_extracted_chunks{0};
+    std::vector<std::future<void>> insert_futures;
+    std::vector<std::future<void>> extract_futures;
+
+    auto gen_keys = [this](uint32_t tid) {
+        std::unordered_set<uint32_t> keys;
+        keys.reserve(num_keys / num_threads);
+        for (uint32_t key = tid; key < num_keys; key += num_threads) {
+            keys.insert(key);
+        }
+        return keys;
+    };
+
+    // insert & spill threads
+    for (uint32_t tid = 0; tid < num_threads; ++tid) {
+        insert_futures.emplace_back(std::async(std::launch::async, [&, tid] {
+            for (uint32_t p = 0; p < chunks_per_key; ++p) {
+                for (auto key : gen_keys(tid)) {
+                    // Create chunk with 1KB device buffer and 1KB host buffer
+                    auto chunk_id =
+                        chunk_id_counter.fetch_add(1, std::memory_order_relaxed);
+
+                    // Reserve 1KB using buffer resource reserve-and-spill
+                    auto reservation = br->reserve_and_spill(
+                        rapidsmpf::MemoryType::DEVICE, chunk_size, true
+                    );
+
+                    auto metadata = std::make_unique<std::vector<uint8_t>>(
+                        chunk_size, static_cast<uint8_t>(key)
+                    );
+                    auto buffer = br->allocate(stream, std::move(reservation));
+
+                    postbox->insert(
+                        rapidsmpf::shuffler::detail::Chunk::from_packed_data(
+                            chunk_id, key, {std::move(metadata), std::move(buffer)}
+                        )
+                    );
+                }
+            }
+            // Signal that this insert thread is done
+            completed_insert_threads.fetch_add(1, std::memory_order_release);
+        }));
+    }
+
+    auto insert_done = [&]() {
+        return completed_insert_threads.load(std::memory_order_acquire) == num_threads;
+    };
+
+    // extract threads
+    for (uint32_t tid = 0; tid < num_threads; ++tid) {
+        extract_futures.emplace_back(std::async(std::launch::async, [&, tid] {
+            auto keys = gen_keys(tid);
+
+            while (!keys.empty()) {  // extract until all keys are empty
+                for (auto it = keys.begin(); it != keys.end();) {
+                    if (!postbox->is_empty(*it)) {
+                        auto chunks = postbox->extract(*it);
+                        auto chunk_count = static_cast<uint32_t>(chunks.size());
+                        num_extracted_chunks.fetch_add(
+                            chunk_count, std::memory_order_relaxed
+                        );
+                        ++it;
+                    } else if (insert_done() && postbox->is_empty(*it)) {
+                        it = keys.erase(it);
+                    } else {
+                        ++it;
+                    }
+                }
+            }
+        }));
+    }
+
+    for (auto& future : insert_futures) {
+        EXPECT_NO_THROW(future.get());
+    }
+
+    for (auto& future : extract_futures) {
+        EXPECT_NO_THROW(future.get());
+    }
+
+    // Verify that all chunks were extracted
+    uint32_t expected_total_chunks = num_keys * chunks_per_key;
+    SCOPED_TRACE("postbox: " + postbox->str());
+    EXPECT_EQ(expected_total_chunks, chunk_id_counter.load());
+    EXPECT_EQ(expected_total_chunks, num_extracted_chunks.load());
+    EXPECT_TRUE(postbox->empty());
+}
+
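The `gen_keys()` striding gives each thread pair a disjoint key set, so insert thread `i` and extract thread `i` contend only with each other and with the spill callback. A standalone demo of the striding (same arithmetic, no test scaffolding):

    #include <cstdint>
    #include <iostream>

    int main() {
        uint32_t const num_threads = 2, num_keys = 6;
        for (uint32_t tid = 0; tid < num_threads; ++tid) {
            std::cout << "thread " << tid << ":";
            for (uint32_t key = tid; key < num_keys; key += num_threads) {
                std::cout << ' ' << key;  // thread 0: 0 2 4, thread 1: 1 3 5
            }
            std::cout << '\n';
        }
        return 0;
    }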
+INSTANTIATE_TEST_SUITE_P(
+    PostBoxConcurrent,
+    PostBoxMultithreadedTest,
+    testing::Combine(
+        testing::Values(1, 2),  // num_threads
+        testing::Values(10, 100),  // num_keys
+        testing::Values(1, 10)  // chunks_per_key
+    ),
+    [](const testing::TestParamInfo<PostBoxMultithreadedTest::ParamType>& info) {
+        return "nthreads_" + std::to_string(std::get<0>(info.param)) + "_keys_"
+               + std::to_string(std::get<1>(info.param)) + "_chunks_"
+               + std::to_string(std::get<2>(info.param));
+    }
+);