From 671cc1aeac9184a1ced68d607e75a305c8633f62 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 21 Jan 2026 11:21:19 +0000 Subject: [PATCH 01/14] Estimate adaptive join input sizes Sample a small number of chunks from each side to estimate input size. The goal is to make an early join-strategy choice without stalling the pipeline or buffering whole inputs. This adds an allgather so the estimate reflects all ranks while keeping the local sample size minimal. --- cpp/benchmarks/streaming/ndsh/join.cpp | 127 +++++++++++++++++++++++++ 1 file changed, 127 insertions(+) diff --git a/cpp/benchmarks/streaming/ndsh/join.cpp b/cpp/benchmarks/streaming/ndsh/join.cpp index 52b5f8f6c..53d995f64 100644 --- a/cpp/benchmarks/streaming/ndsh/join.cpp +++ b/cpp/benchmarks/streaming/ndsh/join.cpp @@ -6,7 +6,10 @@ #include "join.hpp" #include +#include +#include #include +#include #include #include @@ -25,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -403,4 +407,127 @@ streaming::Node shuffle( co_await ch_out->drain(ctx->executor()); } +namespace { +coro::task> allgather_join_sizes( + std::shared_ptr ctx, + OpID tag, + std::size_t left_sample_bytes, + std::size_t right_sample_bytes +) { + auto metadata = std::make_unique>(2 * sizeof(std::size_t)); + std::memcpy(metadata->data(), &left_sample_bytes, sizeof(left_sample_bytes)); + std::memcpy( + metadata->data() + sizeof(left_sample_bytes), + &right_sample_bytes, + sizeof(right_sample_bytes) + ); + + auto stream = ctx->br()->stream_pool().get_stream(); + auto [res, _] = ctx->br()->reserve(MemoryType::HOST, 0, true); + auto buf = ctx->br()->allocate(stream, std::move(res)); + auto allgather = streaming::AllGather(ctx, tag); + allgather.insert(0, {PackedData(std::move(metadata), std::move(buf))}); + allgather.insert_finished(); + auto per_rank = co_await allgather.extract_all(streaming::AllGather::Ordered::NO); + + std::size_t left_total_bytes = 0; + std::size_t right_total_bytes = 0; + for (auto const& data : per_rank) { + // Assumption: each rank packs two size_t values into metadata. + RAPIDSMPF_EXPECTS( + data.metadata->size() >= 2 * sizeof(std::size_t), + "Invalid metadata size for adaptive join size estimation" + ); + std::size_t rank_left = 0; + std::size_t rank_right = 0; + std::memcpy(&rank_left, data.metadata->data(), sizeof(rank_left)); + std::memcpy( + &rank_right, data.metadata->data() + sizeof(rank_left), sizeof(rank_right) + ); + left_total_bytes += rank_left; + right_total_bytes += rank_right; + } + co_return {left_total_bytes, right_total_bytes}; +} +} // namespace + +streaming::Node adaptive_inner_join( + std::shared_ptr ctx, + std::shared_ptr left, + std::shared_ptr right, + std::shared_ptr ch_out, + std::vector left_keys, + std::vector right_keys, + OpID allreduce_tag, + OpID left_shuffle_tag, + OpID right_shuffle_tag +) { + streaming::ShutdownAtExit c{left, right, ch_out}; + co_await ctx->executor()->schedule(); + (void)left_keys; + (void)right_keys; + (void)allreduce_tag; + (void)left_shuffle_tag; + (void)right_shuffle_tag; + // Assumption: the input channels carry only TableChunk messages. + // Assumption: summing data_alloc_size across memory types is a good proxy for the + // amount of data that will need to be materialized on device for compute. + // Assumption: a small sample of chunks is representative of the whole table size. 
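+    // Note: only the first `inspect_messages` chunks per side are counted below
+    // and no extrapolation is applied yet, so the estimate is a lower bound on
+    // each input's true size.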
+ constexpr std::size_t inspect_messages = 2; + std::vector left_buffer; + std::vector right_buffer; + left_buffer.reserve(inspect_messages + 1); + right_buffer.reserve(inspect_messages + 1); + + auto inspect_channel = + [&](std::shared_ptr ch, + std::vector& buffer) -> coro::task { + std::size_t bytes = 0; + for (std::size_t i = 0; i < inspect_messages; ++i) { + auto msg = co_await ch->receive(); + if (msg.empty()) { + // Preserve the termination marker for downstream processing. + buffer.push_back(std::move(msg)); + co_return bytes; + } + RAPIDSMPF_EXPECTS( + msg.holds(), + "adaptive_inner_join expects TableChunk messages" + ); + auto const& chunk = msg.get(); + for (auto mem_type : MEMORY_TYPES) { + bytes += chunk.data_alloc_size(mem_type); + } + buffer.push_back(std::move(msg)); + } + co_return bytes; + }; + + auto left_sample_bytes = co_await inspect_channel(left, left_buffer); + auto right_sample_bytes = co_await inspect_channel(right, right_buffer); + + std::size_t left_total_bytes = left_sample_bytes; + std::size_t right_total_bytes = right_sample_bytes; + if (ctx->comm()->nranks() > 1) { + std::tie(left_total_bytes, right_total_bytes) = co_await allgather_join_sizes( + ctx, allreduce_tag, left_sample_bytes, right_sample_bytes + ); + } + + ctx->comm()->logger().print( + "Adaptive join sample sizes: left ", + left_sample_bytes, + " bytes, right ", + right_sample_bytes, + " bytes" + ); + ctx->comm()->logger().print( + "Adaptive join total sizes: left ", + left_total_bytes, + " bytes, right ", + right_total_bytes, + " bytes" + ); + co_await ch_out->drain(ctx->executor()); +} } // namespace rapidsmpf::ndsh From 43fd80b8b1efad575ddfe56810da126ec2c6172c Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 21 Jan 2026 11:30:48 +0000 Subject: [PATCH 02/14] WIP: add an adaptive join benchmark Exploring how to manage approximate channel message counts in an actor network. --- cpp/benchmarks/streaming/ndsh/CMakeLists.txt | 2 +- .../streaming/ndsh/adaptive_join.cpp | 509 ++++++++++++++++++ 2 files changed, 510 insertions(+), 1 deletion(-) create mode 100644 cpp/benchmarks/streaming/ndsh/adaptive_join.cpp diff --git a/cpp/benchmarks/streaming/ndsh/CMakeLists.txt b/cpp/benchmarks/streaming/ndsh/CMakeLists.txt index 5b9f2dc6e..cc5b453fc 100644 --- a/cpp/benchmarks/streaming/ndsh/CMakeLists.txt +++ b/cpp/benchmarks/streaming/ndsh/CMakeLists.txt @@ -41,7 +41,7 @@ target_link_libraries( $ maybe_asan ) -set(RAPIDSMPFNDSH_QUERIES q01 q03 q09 q21 bench_read) +set(RAPIDSMPFNDSH_QUERIES q01 q03 q09 q21 bench_read adaptive_join) foreach(query IN ITEMS ${RAPIDSMPFNDSH_QUERIES}) add_executable(${query} "${query}.cpp") diff --git a/cpp/benchmarks/streaming/ndsh/adaptive_join.cpp b/cpp/benchmarks/streaming/ndsh/adaptive_join.cpp new file mode 100644 index 000000000..63e258731 --- /dev/null +++ b/cpp/benchmarks/streaming/ndsh/adaptive_join.cpp @@ -0,0 +1,509 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "utils.hpp" + +#include + +namespace { + +rapidsmpf::streaming::Node read_parquet( + std::shared_ptr ctx, + std::shared_ptr ch_out, + std::size_t num_producers, + cudf::size_type num_rows_per_chunk, + std::optional> columns, + std::string const& input_directory, + std::string const& input_file +) { + auto files = rapidsmpf::ndsh::detail::list_parquet_files( + rapidsmpf::ndsh::detail::get_table_path(input_directory, input_file) + ); + auto options = + cudf::io::parquet_reader_options::builder(cudf::io::source_info(files)).build(); + if (columns.has_value()) { + options.set_columns(*columns); + } + return rapidsmpf::streaming::node::read_parquet( + ctx, ch_out, num_producers, options, num_rows_per_chunk + ); +} + +std::size_t estimate_read_parquet_messages( + std::shared_ptr ctx, + std::string const& input_directory, + std::string const& input_file, + cudf::size_type num_rows_per_chunk +) { + auto files = rapidsmpf::ndsh::detail::list_parquet_files( + rapidsmpf::ndsh::detail::get_table_path(input_directory, input_file) + ); + if (files.empty()) { + return 0; + } + + // Assumption: this file-to-rank mapping matches read_parquet in the streaming node. + auto const rank = static_cast(ctx->comm()->rank()); + auto const size = static_cast(ctx->comm()->nranks()); + auto const base = files.size() / size; + auto const extra = files.size() % size; + auto const files_per_rank = base + (rank < extra ? 1 : 0); + auto const file_offset = rank * base + std::min(rank, extra); + if (files_per_rank == 0) { + return 0; + } + + std::size_t total_rows = 0; + for (std::size_t i = 0; i < files_per_rank; ++i) { + auto const& file = files[file_offset + i]; + total_rows += static_cast( + cudf::io::read_parquet_metadata(cudf::io::source_info(file)).num_rows() + ); + } + if (total_rows == 0 || num_rows_per_chunk <= 0) { + return 0; + } + + // Assumption: chunk sizes are close to num_rows_per_chunk and filters are absent. 
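+    // Ceiling division: a final partial chunk still produces one message.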
+ auto const chunk_rows = static_cast(num_rows_per_chunk); + return (total_rows + chunk_rows - 1) / chunk_rows; +} + +rapidsmpf::streaming::Node advertise_message_count( + std::shared_ptr ctx, + std::shared_ptr ch_meta, + std::size_t estimate +) { + rapidsmpf::streaming::ShutdownAtExit c{ch_meta}; + co_await ctx->executor()->schedule(); + auto payload = std::make_unique(estimate); + co_await ch_meta->send(rapidsmpf::streaming::Message{0, std::move(payload), {}, {}}); + co_await ch_meta->drain(ctx->executor()); +} + +rapidsmpf::streaming::Node consume_message_count( + std::shared_ptr ctx, + std::shared_ptr ch_meta +) { + rapidsmpf::streaming::ShutdownAtExit c{ch_meta}; + co_await ctx->executor()->schedule(); + auto msg = co_await ch_meta->receive(); + if (!msg.empty()) { + RAPIDSMPF_EXPECTS( + msg.holds(), "Expected size_t message count estimate" + ); + ctx->comm()->logger().print("Estimated message count: ", msg.get()); + } + co_await ch_meta->drain(ctx->executor()); +} + +rapidsmpf::streaming::Node consume_channel_parallel( + std::shared_ptr ctx, + std::shared_ptr ch_in, + std::size_t num_consumers +) { + rapidsmpf::streaming::ShutdownAtExit c{ch_in}; + std::atomic estimated_total_bytes{0}; + auto task = [&]() -> rapidsmpf::streaming::Node { + co_await ctx->executor()->schedule(); + while (true) { + auto msg = co_await ch_in->receive(); + if (msg.empty()) { + break; + } + if (msg.holds()) { + auto chunk = rapidsmpf::ndsh::to_device( + ctx, msg.release() + ); + ctx->comm()->logger().print( + "Consumed chunk with ", + chunk.table_view().num_rows(), + " rows and ", + chunk.table_view().num_columns(), + " columns" + ); + estimated_total_bytes.fetch_add( + chunk.data_alloc_size(rapidsmpf::MemoryType::DEVICE) + ); + } + } + }; + std::vector tasks; + for (std::size_t i = 0; i < num_consumers; i++) { + tasks.push_back(task()); + } + rapidsmpf::streaming::coro_results(co_await coro::when_all(std::move(tasks))); + ctx->comm()->logger().print( + "Table was around ", rmm::detail::format_bytes(estimated_total_bytes.load()) + ); +} + +///< @brief Configuration options for the benchmark +struct ProgramOptions { + int num_streaming_threads{1}; ///< Number of streaming threads to use + int num_iterations{2}; ///< Number of iterations of query to run + int num_streams{16}; ///< Number of streams in stream pool + rapidsmpf::ndsh::CommType comm_type{ + rapidsmpf::ndsh::CommType::UCXX + }; ///< Type of communicator to create + cudf::size_type num_rows_per_chunk{ + 100'000'000 + }; ///< Number of rows to produce per chunk read + std::size_t num_producers{ + 1 + }; ///< Number of simultaneous read_parquet chunk producers. + std::size_t num_consumers{1}; ///< Number of simultaneous chunk consumers. + std::string input_directory; ///< Directory containing input files. + std::string input_file; ///< Basename of input file to read. + std::optional> columns{std::nullopt}; ///< Columns to read. 
+}; + +ProgramOptions parse_arguments(int argc, char** argv) { + ProgramOptions options; + + static constexpr std:: + array(rapidsmpf::ndsh::CommType::MAX)> + comm_names{"single", "mpi", "ucxx"}; + + auto print_usage = [&argv, &options]() { + std::cerr + << "Usage: " << argv[0] << " [options]\n" + << "Options:\n" + << " --num-streaming-threads Number of streaming threads (default: " + << options.num_streaming_threads << ")\n" + << " --num-iterations Number of iterations (default: " + << options.num_iterations << ")\n" + << " --num-streams Number of streams in stream pool " + "(default: " + << options.num_streams << ")\n" + << " --num-rows-per-chunk Number of rows per chunk (default: " + << options.num_rows_per_chunk << ")\n" + << " --num-producers Number of concurrent read_parquet " + "producers (default: " + << options.num_producers << ")\n" + << " --num-consumers Number of concurrent consumers (default: " + << options.num_consumers << ")\n" + << " --comm-type Communicator type: single, mpi, ucxx " + "(default: " + << comm_names[static_cast(options.comm_type)] << ")\n" + << " --input-directory Input directory path (required)\n" + << " --input-file Input file basename relative to input " + "directory (required)\n" + << " --columns Comma-separated column names to read " + "(optional, default all columns)\n" + << " --help Show this help message\n"; + }; + + // NOLINTBEGIN(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,modernize-use-designated-initializers) + static struct option long_options[] = { + {"num-streaming-threads", required_argument, nullptr, 1}, + {"num-rows-per-chunk", required_argument, nullptr, 2}, + {"num-producers", required_argument, nullptr, 3}, + {"num-consumers", required_argument, nullptr, 4}, + {"input-directory", required_argument, nullptr, 5}, + {"input-file", required_argument, nullptr, 6}, + {"help", no_argument, nullptr, 7}, + {"num-iterations", required_argument, nullptr, 8}, + {"num-streams", required_argument, nullptr, 9}, + {"comm-type", required_argument, nullptr, 10}, + {"columns", required_argument, nullptr, 11}, + {nullptr, 0, nullptr, 0} + }; + // NOLINTEND(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,modernize-use-designated-initializers) + + int opt; + int option_index = 0; + + bool saw_input_directory = false; + bool saw_input_file = false; + + auto parse_i64 = [](char const* s, char const* opt_name) -> long long { + if (s == nullptr || *s == '\0') { + std::cerr << "Error: " << opt_name << " requires a value\n"; + std::exit(1); + } + errno = 0; + char* end = nullptr; + auto const v = std::strtoll(s, &end, 10); + if (errno != 0 || end == s || *end != '\0') { + std::cerr << "Error: invalid integer for " << opt_name << ": '" << s << "'\n"; + std::exit(1); + } + return v; + }; + + auto parse_u64 = [](char const* s, char const* opt_name) -> unsigned long long { + if (s == nullptr || *s == '\0') { + std::cerr << "Error: " << opt_name << " requires a value\n"; + std::exit(1); + } + errno = 0; + char* end = nullptr; + auto const v = std::strtoull(s, &end, 10); + if (errno != 0 || end == s || *end != '\0') { + std::cerr << "Error: invalid non-negative integer for " << opt_name << ": '" + << s << "'\n"; + std::exit(1); + } + return v; + }; + + auto require_positive_i32 = [&](char const* s, char const* opt_name) -> int { + auto const v = parse_i64(s, opt_name); + if (v <= 0 || v > std::numeric_limits::max()) { + std::cerr << "Error: " << opt_name << " must be in [1, " + << std::numeric_limits::max() << "], got '" << s << "'\n"; + 
std::exit(1); + } + return static_cast(v); + }; + + auto require_positive_size_t = [&](char const* s, + char const* opt_name) -> std::size_t { + auto const v = parse_u64(s, opt_name); + if (v == 0 || v > std::numeric_limits::max()) { + std::cerr << "Error: " << opt_name << " must be in [1, " + << std::numeric_limits::max() << "], got '" << s + << "'\n"; + std::exit(1); + } + return static_cast(v); + }; + + auto parse_columns = [](char const* s) -> std::optional> { + if (s == nullptr) { + return std::nullopt; + } + std::string str{s}; + if (str.empty()) { + return std::nullopt; + } + std::vector cols; + std::size_t start = 0; + while (start <= str.size()) { + auto const comma = str.find(',', start); + auto const end = (comma == std::string::npos) ? str.size() : comma; + auto const token = str.substr(start, end - start); + if (token.empty()) { + std::cerr << "Error: --columns contains an empty column name\n"; + std::exit(1); + } + cols.push_back(token); + if (comma == std::string::npos) { + break; + } + start = comma + 1; + } + return cols; + }; + + while ((opt = getopt_long(argc, argv, "", long_options, &option_index)) != -1) { + switch (opt) { + case 1: // --num-streaming-threads + options.num_streaming_threads = + require_positive_i32(optarg, "--num-streaming-threads"); + break; + case 2: // --num-rows-per-chunk + options.num_rows_per_chunk = + require_positive_i32(optarg, "--num-rows-per-chunk"); + break; + case 3: // --num-producers + options.num_producers = require_positive_size_t(optarg, "--num-producers"); + break; + case 4: // --num-consumers + options.num_consumers = require_positive_size_t(optarg, "--num-consumers"); + break; + case 5: // --input-directory + if (optarg == nullptr || *optarg == '\0') { + std::cerr << "Error: --input-directory requires a non-empty value\n"; + std::exit(1); + } + options.input_directory = optarg; + saw_input_directory = true; + break; + case 6: // --input-file + if (optarg == nullptr || *optarg == '\0') { + std::cerr << "Error: --input-file requires a non-empty value\n"; + std::exit(1); + } + options.input_file = optarg; + saw_input_file = true; + break; + case 7: // --help + print_usage(); + std::exit(0); + case 8: // --num-iterations + options.num_iterations = require_positive_i32(optarg, "--num-iterations"); + break; + case 9: // --num-streams + options.num_streams = require_positive_i32(optarg, "--num-streams"); + break; + case 10: + { // --comm-type + if (optarg == nullptr || *optarg == '\0') { + std::cerr << "Error: --comm-type requires a value\n"; + std::exit(1); + } + std::string_view const s{optarg}; + auto parsed = std::optional{}; + for (std::size_t i = 0; i < comm_names.size(); ++i) { + if (s == comm_names[i]) { + parsed = static_cast(i); + break; + } + } + if (!parsed.has_value()) { + std::cerr << "Error: invalid --comm-type '" << s + << "' (expected: single, mpi, ucxx)\n"; + std::exit(1); + } + options.comm_type = *parsed; + break; + } + case 11: // --columns + options.columns = parse_columns(optarg); + break; + case '?': + if (optopt == 0 && optind > 1) { + std::cerr << "Error: Unknown option '" << argv[optind - 1] << "'\n\n"; + } + print_usage(); + std::exit(1); + default: + print_usage(); + std::exit(1); + } + } + + // Check if required options were provided + if (!saw_input_directory || !saw_input_file) { + if (!saw_input_directory) { + std::cerr << "Error: --input-directory is required\n"; + } + if (!saw_input_file) { + std::cerr << "Error: --input-file is required\n"; + } + std::cerr << std::endl; + print_usage(); + std::exit(1); 
+ } + + return options; +} + +} // namespace + +/** + * @brief Run a simple benchmark reading a table from parquet files. + */ +int main(int argc, char** argv) { + rapidsmpf::ndsh::FinalizeMPI finalize{}; + cudaFree(nullptr); + // work around https://github.com/rapidsai/cudf/issues/20849 + cudf::initialize(); + auto mr = rmm::mr::cuda_async_memory_resource{}; + auto stats_wrapper = rapidsmpf::RmmResourceAdaptor(&mr); + auto arguments = parse_arguments(argc, argv); + rapidsmpf::ndsh::ProgramOptions ctx_arguments{ + .num_streaming_threads = arguments.num_streaming_threads, + .num_iterations = arguments.num_iterations, + .num_streams = arguments.num_streams, + .comm_type = arguments.comm_type, + .num_rows_per_chunk = arguments.num_rows_per_chunk, + .output_file = "", + .input_directory = arguments.input_directory + }; + + auto ctx = rapidsmpf::ndsh::create_context(ctx_arguments, &stats_wrapper); + std::vector timings; + for (int i = 0; i < arguments.num_iterations; i++) { + std::vector nodes; + auto start = std::chrono::steady_clock::now(); + { + RAPIDSMPF_NVTX_SCOPED_RANGE("Constructing read_parquet pipeline"); + + // Input data channels + auto ch_out = ctx->create_channel(); + auto ch_meta = ctx->create_channel(); + auto estimate = estimate_read_parquet_messages( + ctx, + arguments.input_directory, + arguments.input_file, + arguments.num_rows_per_chunk + ); + nodes.push_back(read_parquet( + ctx, + ch_out, + arguments.num_producers, + arguments.num_rows_per_chunk, + arguments.columns, + arguments.input_directory, + arguments.input_file + )); + nodes.push_back(advertise_message_count(ctx, ch_meta, estimate)); + nodes.push_back(consume_message_count(ctx, ch_meta)); + nodes.push_back( + consume_channel_parallel(ctx, ch_out, arguments.num_consumers) + ); + } + auto end = std::chrono::steady_clock::now(); + std::chrono::duration pipeline = end - start; + start = std::chrono::steady_clock::now(); + { + RAPIDSMPF_NVTX_SCOPED_RANGE("read_parquet iteration"); + rapidsmpf::streaming::run_streaming_pipeline(std::move(nodes)); + } + end = std::chrono::steady_clock::now(); + std::chrono::duration compute = end - start; + timings.push_back(pipeline.count()); + timings.push_back(compute.count()); + ctx->comm()->logger().print(ctx->statistics()->report()); + ctx->statistics()->clear(); + } + + if (ctx->comm()->rank() == 0) { + for (int i = 0; i < arguments.num_iterations; i++) { + ctx->comm()->logger().print( + "Iteration ", + i, + " pipeline construction time [s]: ", + timings[size_t(2 * i)] + ); + ctx->comm()->logger().print( + "Iteration ", i, " compute time [s]: ", timings[size_t(2 * i + 1)] + ); + } + } + return 0; +} From 885fc1047c0fb7f58d2375d3f83abaacbb0e3088 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 21 Jan 2026 11:42:55 +0000 Subject: [PATCH 03/14] Add adaptive_inner_join prototype --- .../streaming/ndsh/adaptive_join.cpp | 4 ++- cpp/benchmarks/streaming/ndsh/join.hpp | 32 ++++++++++++++++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/cpp/benchmarks/streaming/ndsh/adaptive_join.cpp b/cpp/benchmarks/streaming/ndsh/adaptive_join.cpp index 63e258731..9d36a6190 100644 --- a/cpp/benchmarks/streaming/ndsh/adaptive_join.cpp +++ b/cpp/benchmarks/streaming/ndsh/adaptive_join.cpp @@ -147,7 +147,9 @@ rapidsmpf::streaming::Node consume_channel_parallel( ctx, msg.release() ); ctx->comm()->logger().print( - "Consumed chunk with ", + "Consumed chunk ", + msg.sequence_number(), + " with ", chunk.table_view().num_rows(), " rows and ", chunk.table_view().num_columns(), diff 
--git a/cpp/benchmarks/streaming/ndsh/join.hpp b/cpp/benchmarks/streaming/ndsh/join.hpp
index 67e701817..b05299d65 100644
--- a/cpp/benchmarks/streaming/ndsh/join.hpp
+++ b/cpp/benchmarks/streaming/ndsh/join.hpp
@@ -1,5 +1,5 @@
 /**
- * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  */
@@ -141,4 +141,34 @@ enum class KeepKeys : bool {
     OpID tag
 );
 
+/**
+ * @brief Perform a streaming inner join between two tables using an adaptive strategy.
+ *
+ * @note This inspects a small prefix of each input channel to estimate input sizes and
+ * uses that estimate to choose between broadcast and shuffle join implementations.
+ *
+ * @param ctx Streaming context.
+ * @param left Channel of `TableChunk`s from the left table.
+ * @param right Channel of `TableChunk`s from the right table.
+ * @param ch_out Output channel of `TableChunk`s.
+ * @param left_keys Column indices of the keys in the left table.
+ * @param right_keys Column indices of the keys in the right table.
+ * @param allreduce_tag Disambiguating tag for the size estimation allgather.
+ * @param left_shuffle_tag Disambiguating tag for the left shuffle (if used).
+ * @param right_shuffle_tag Disambiguating tag for the right shuffle (if used).
+ *
+ * @return Coroutine representing the completion of the join.
+ */
+streaming::Node adaptive_inner_join(
+    std::shared_ptr ctx,
+    std::shared_ptr left,
+    std::shared_ptr right,
+    std::shared_ptr ch_out,
+    std::vector left_keys,
+    std::vector right_keys,
+    OpID allreduce_tag,
+    OpID left_shuffle_tag,
+    OpID right_shuffle_tag
+);
+
 }  // namespace rapidsmpf::ndsh

From 3ee7bd1237da35d60653ab18538a387114bd2720 Mon Sep 17 00:00:00 2001
From: Lawrence Mitchell
Date: Wed, 21 Jan 2026 16:20:22 +0000
Subject: [PATCH 04/14] Fix bug in early exit of Lineariser when output is shut down

We need to break the outer loop, not the inner one.
---
 cpp/include/rapidsmpf/streaming/core/lineariser.hpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/cpp/include/rapidsmpf/streaming/core/lineariser.hpp b/cpp/include/rapidsmpf/streaming/core/lineariser.hpp
index edad29951..237bb4da0 100644
--- a/cpp/include/rapidsmpf/streaming/core/lineariser.hpp
+++ b/cpp/include/rapidsmpf/streaming/core/lineariser.hpp
@@ -1,5 +1,5 @@
 /**
- * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  */
@@ -94,7 +94,8 @@ class Lineariser {
     Node drain() {
         ShutdownAtExit c{ch_out_};
         co_await ctx_->executor()->schedule();
-        while (!queues_.empty()) {
+        bool should_continue = true;
+        while (should_continue && !queues_.empty()) {
             for (auto& q : queues_) {
                 auto [receipt, msg] = co_await q->receive();
                 if (msg.empty()) {
@@ -103,6 +104,7 @@ class Lineariser {
                 }
                 if (!co_await ch_out_->send(std::move(msg))) {
                     // Output channel is shut down, tell the producers to shutdown.
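+                    // Breaking only the inner for-loop would re-enter the outer
+                    // drain loop and keep polling the queues, so clear the flag
+                    // to stop the outer loop as well.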
+ should_continue = false; break; } co_await receipt; From 93d726cd45cc082541cb8fb58f224cda5b2df94e Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 21 Jan 2026 16:36:41 +0000 Subject: [PATCH 05/14] Work on adaptive join, starting to hook things up --- .../streaming/ndsh/adaptive_join.cpp | 196 +++++++++++------- cpp/benchmarks/streaming/ndsh/join.cpp | 84 ++++---- cpp/benchmarks/streaming/ndsh/join.hpp | 4 + 3 files changed, 168 insertions(+), 116 deletions(-) diff --git a/cpp/benchmarks/streaming/ndsh/adaptive_join.cpp b/cpp/benchmarks/streaming/ndsh/adaptive_join.cpp index 9d36a6190..eb4c24be8 100644 --- a/cpp/benchmarks/streaming/ndsh/adaptive_join.cpp +++ b/cpp/benchmarks/streaming/ndsh/adaptive_join.cpp @@ -32,6 +32,7 @@ #include #include +#include "join.hpp" #include "utils.hpp" #include @@ -60,7 +61,7 @@ rapidsmpf::streaming::Node read_parquet( ); } -std::size_t estimate_read_parquet_messages( +[[maybe_unused]] std::size_t estimate_read_parquet_messages( std::shared_ptr ctx, std::string const& input_directory, std::string const& input_file, @@ -100,7 +101,7 @@ std::size_t estimate_read_parquet_messages( return (total_rows + chunk_rows - 1) / chunk_rows; } -rapidsmpf::streaming::Node advertise_message_count( +[[maybe_unused]] rapidsmpf::streaming::Node advertise_message_count( std::shared_ptr ctx, std::shared_ptr ch_meta, std::size_t estimate @@ -110,64 +111,40 @@ rapidsmpf::streaming::Node advertise_message_count( auto payload = std::make_unique(estimate); co_await ch_meta->send(rapidsmpf::streaming::Message{0, std::move(payload), {}, {}}); co_await ch_meta->drain(ctx->executor()); + ctx->comm()->logger().print("Exiting message count"); } -rapidsmpf::streaming::Node consume_message_count( - std::shared_ptr ctx, - std::shared_ptr ch_meta -) { - rapidsmpf::streaming::ShutdownAtExit c{ch_meta}; - co_await ctx->executor()->schedule(); - auto msg = co_await ch_meta->receive(); - if (!msg.empty()) { - RAPIDSMPF_EXPECTS( - msg.holds(), "Expected size_t message count estimate" - ); - ctx->comm()->logger().print("Estimated message count: ", msg.get()); - } - co_await ch_meta->drain(ctx->executor()); -} - -rapidsmpf::streaming::Node consume_channel_parallel( +[[maybe_unused]] rapidsmpf::streaming::Node consume_channel_parallel( std::shared_ptr ctx, std::shared_ptr ch_in, - std::size_t num_consumers + std::size_t ) { rapidsmpf::streaming::ShutdownAtExit c{ch_in}; - std::atomic estimated_total_bytes{0}; - auto task = [&]() -> rapidsmpf::streaming::Node { - co_await ctx->executor()->schedule(); - while (true) { - auto msg = co_await ch_in->receive(); - if (msg.empty()) { - break; - } - if (msg.holds()) { - auto chunk = rapidsmpf::ndsh::to_device( - ctx, msg.release() - ); - ctx->comm()->logger().print( - "Consumed chunk ", - msg.sequence_number(), - " with ", - chunk.table_view().num_rows(), - " rows and ", - chunk.table_view().num_columns(), - " columns" - ); - estimated_total_bytes.fetch_add( - chunk.data_alloc_size(rapidsmpf::MemoryType::DEVICE) - ); - } + std::size_t estimated_total_bytes{0}; + co_await ctx->executor()->schedule(); + while (true) { + auto msg = co_await ch_in->receive(); + if (msg.empty()) { + break; + } + if (msg.holds()) { + auto chunk = rapidsmpf::ndsh::to_device( + ctx, msg.release() + ); + ctx->comm()->logger().print( + "Consumed chunk ", + msg.sequence_number(), + " with ", + chunk.table_view().num_rows(), + " rows and ", + chunk.table_view().num_columns(), + " columns" + ); + estimated_total_bytes += chunk.data_alloc_size(rapidsmpf::MemoryType::DEVICE); } 
- }; - std::vector tasks; - for (std::size_t i = 0; i < num_consumers; i++) { - tasks.push_back(task()); } - rapidsmpf::streaming::coro_results(co_await coro::when_all(std::move(tasks))); ctx->comm()->logger().print( - "Table was around ", rmm::detail::format_bytes(estimated_total_bytes.load()) + "Table was around ", rmm::detail::format_bytes(estimated_total_bytes) ); } @@ -187,8 +164,14 @@ struct ProgramOptions { }; ///< Number of simultaneous read_parquet chunk producers. std::size_t num_consumers{1}; ///< Number of simultaneous chunk consumers. std::string input_directory; ///< Directory containing input files. - std::string input_file; ///< Basename of input file to read. - std::optional> columns{std::nullopt}; ///< Columns to read. + std::string left_input_file; ///< Basename of left input file to read. + std::string right_input_file; ///< Basename of right input file to read. + std::optional> left_columns{ + std::nullopt + }; ///< Columns to read (left input). + std::optional> right_columns{ + std::nullopt + }; ///< Columns to read (right input). }; ProgramOptions parse_arguments(int argc, char** argv) { @@ -220,9 +203,14 @@ ProgramOptions parse_arguments(int argc, char** argv) { "(default: " << comm_names[static_cast(options.comm_type)] << ")\n" << " --input-directory Input directory path (required)\n" - << " --input-file Input file basename relative to input " + << " --left-input-file Left input file basename relative to " + "input " "directory (required)\n" - << " --columns Comma-separated column names to read " + << " --right-input-file Right input file basename relative to " + "input directory (required)\n" + << " --left-columns Comma-separated column names to read " + "(optional, default all columns)\n" + << " --right-columns Comma-separated column names to read " "(optional, default all columns)\n" << " --help Show this help message\n"; }; @@ -234,12 +222,14 @@ ProgramOptions parse_arguments(int argc, char** argv) { {"num-producers", required_argument, nullptr, 3}, {"num-consumers", required_argument, nullptr, 4}, {"input-directory", required_argument, nullptr, 5}, - {"input-file", required_argument, nullptr, 6}, + {"left-input-file", required_argument, nullptr, 6}, + {"right-input-file", required_argument, nullptr, 12}, {"help", no_argument, nullptr, 7}, {"num-iterations", required_argument, nullptr, 8}, {"num-streams", required_argument, nullptr, 9}, {"comm-type", required_argument, nullptr, 10}, - {"columns", required_argument, nullptr, 11}, + {"left-columns", required_argument, nullptr, 11}, + {"right-columns", required_argument, nullptr, 13}, {nullptr, 0, nullptr, 0} }; // NOLINTEND(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,modernize-use-designated-initializers) @@ -248,7 +238,8 @@ ProgramOptions parse_arguments(int argc, char** argv) { int option_index = 0; bool saw_input_directory = false; - bool saw_input_file = false; + bool saw_left_input_file = false; + bool saw_right_input_file = false; auto parse_i64 = [](char const* s, char const* opt_name) -> long long { if (s == nullptr || *s == '\0') { @@ -354,13 +345,21 @@ ProgramOptions parse_arguments(int argc, char** argv) { options.input_directory = optarg; saw_input_directory = true; break; - case 6: // --input-file + case 6: // --left-input-file + if (optarg == nullptr || *optarg == '\0') { + std::cerr << "Error: --left-input-file requires a non-empty value\n"; + std::exit(1); + } + options.left_input_file = optarg; + saw_left_input_file = true; + break; + case 12: // --right-input-file if (optarg == 
nullptr || *optarg == '\0') { - std::cerr << "Error: --input-file requires a non-empty value\n"; + std::cerr << "Error: --right-input-file requires a non-empty value\n"; std::exit(1); } - options.input_file = optarg; - saw_input_file = true; + options.right_input_file = optarg; + saw_right_input_file = true; break; case 7: // --help print_usage(); @@ -393,8 +392,11 @@ ProgramOptions parse_arguments(int argc, char** argv) { options.comm_type = *parsed; break; } - case 11: // --columns - options.columns = parse_columns(optarg); + case 11: // --left-columns + options.left_columns = parse_columns(optarg); + break; + case 13: // --right-columns + options.right_columns = parse_columns(optarg); break; case '?': if (optopt == 0 && optind > 1) { @@ -409,12 +411,15 @@ ProgramOptions parse_arguments(int argc, char** argv) { } // Check if required options were provided - if (!saw_input_directory || !saw_input_file) { + if (!saw_input_directory || !saw_left_input_file || !saw_right_input_file) { if (!saw_input_directory) { std::cerr << "Error: --input-directory is required\n"; } - if (!saw_input_file) { - std::cerr << "Error: --input-file is required\n"; + if (!saw_left_input_file) { + std::cerr << "Error: --left-input-file is required\n"; + } + if (!saw_right_input_file) { + std::cerr << "Error: --right-input-file is required\n"; } std::cerr << std::endl; print_usage(); @@ -451,32 +456,69 @@ int main(int argc, char** argv) { std::vector timings; for (int i = 0; i < arguments.num_iterations; i++) { std::vector nodes; + int op_id = 0; auto start = std::chrono::steady_clock::now(); { RAPIDSMPF_NVTX_SCOPED_RANGE("Constructing read_parquet pipeline"); // Input data channels - auto ch_out = ctx->create_channel(); - auto ch_meta = ctx->create_channel(); - auto estimate = estimate_read_parquet_messages( + auto left_out = ctx->create_channel(); + auto right_out = ctx->create_channel(); + auto left_meta = ctx->create_channel(); + auto right_meta = ctx->create_channel(); + auto left_estimate = estimate_read_parquet_messages( ctx, arguments.input_directory, - arguments.input_file, + arguments.left_input_file, + arguments.num_rows_per_chunk + ); + auto right_estimate = estimate_read_parquet_messages( + ctx, + arguments.input_directory, + arguments.right_input_file, arguments.num_rows_per_chunk ); nodes.push_back(read_parquet( ctx, - ch_out, + left_out, arguments.num_producers, arguments.num_rows_per_chunk, - arguments.columns, + arguments.left_columns, arguments.input_directory, - arguments.input_file + arguments.left_input_file )); - nodes.push_back(advertise_message_count(ctx, ch_meta, estimate)); - nodes.push_back(consume_message_count(ctx, ch_meta)); + nodes.push_back(read_parquet( + ctx, + right_out, + arguments.num_producers, + arguments.num_rows_per_chunk, + arguments.right_columns, + arguments.input_directory, + arguments.right_input_file + )); + nodes.push_back(advertise_message_count(ctx, left_meta, left_estimate)); + nodes.push_back(advertise_message_count(ctx, right_meta, right_estimate)); + auto joined = ctx->create_channel(); + auto const size_tag = static_cast(10 * i) + op_id++; + auto const left_shuffle_tag = static_cast(10 * i) + op_id++; + auto const right_shuffle_tag = static_cast(10 * i) + op_id++; + nodes.push_back( + rapidsmpf::ndsh::adaptive_inner_join( + ctx, + left_out, + right_out, + left_meta, + right_meta, + joined, + {0}, + {0}, + size_tag, + left_shuffle_tag, + right_shuffle_tag + ) + ); nodes.push_back( - consume_channel_parallel(ctx, ch_out, arguments.num_consumers) + 
consume_channel_parallel(ctx, joined, arguments.num_consumers) ); } auto end = std::chrono::steady_clock::now(); diff --git a/cpp/benchmarks/streaming/ndsh/join.cpp b/cpp/benchmarks/streaming/ndsh/join.cpp index 53d995f64..6749ebbfe 100644 --- a/cpp/benchmarks/streaming/ndsh/join.cpp +++ b/cpp/benchmarks/streaming/ndsh/join.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -35,6 +36,7 @@ #include #include #include +#include #include #include @@ -411,41 +413,37 @@ namespace { coro::task> allgather_join_sizes( std::shared_ptr ctx, OpID tag, - std::size_t left_sample_bytes, - std::size_t right_sample_bytes + std::size_t left_local_bytes, + std::size_t right_local_bytes ) { auto metadata = std::make_unique>(2 * sizeof(std::size_t)); - std::memcpy(metadata->data(), &left_sample_bytes, sizeof(left_sample_bytes)); + std::memcpy(metadata->data(), &left_local_bytes, sizeof(left_local_bytes)); std::memcpy( - metadata->data() + sizeof(left_sample_bytes), - &right_sample_bytes, - sizeof(right_sample_bytes) + metadata->data() + sizeof(left_local_bytes), + &right_local_bytes, + sizeof(right_local_bytes) ); auto stream = ctx->br()->stream_pool().get_stream(); auto [res, _] = ctx->br()->reserve(MemoryType::HOST, 0, true); auto buf = ctx->br()->allocate(stream, std::move(res)); auto allgather = streaming::AllGather(ctx, tag); - allgather.insert(0, {PackedData(std::move(metadata), std::move(buf))}); + allgather.insert(0, {std::move(metadata), std::move(buf)}); allgather.insert_finished(); auto per_rank = co_await allgather.extract_all(streaming::AllGather::Ordered::NO); std::size_t left_total_bytes = 0; std::size_t right_total_bytes = 0; for (auto const& data : per_rank) { - // Assumption: each rank packs two size_t values into metadata. RAPIDSMPF_EXPECTS( data.metadata->size() >= 2 * sizeof(std::size_t), "Invalid metadata size for adaptive join size estimation" ); - std::size_t rank_left = 0; - std::size_t rank_right = 0; - std::memcpy(&rank_left, data.metadata->data(), sizeof(rank_left)); - std::memcpy( - &rank_right, data.metadata->data() + sizeof(rank_left), sizeof(rank_right) - ); - left_total_bytes += rank_left; - right_total_bytes += rank_right; + std::size_t bytes = 0; + std::memcpy(&bytes, data.metadata->data(), sizeof(bytes)); + left_total_bytes += bytes; + std::memcpy(&bytes, data.metadata->data() + sizeof(bytes), sizeof(bytes)); + right_total_bytes += bytes; } co_return {left_total_bytes, right_total_bytes}; } @@ -455,6 +453,8 @@ streaming::Node adaptive_inner_join( std::shared_ptr ctx, std::shared_ptr left, std::shared_ptr right, + std::shared_ptr left_meta, + std::shared_ptr right_meta, std::shared_ptr ch_out, std::vector left_keys, std::vector right_keys, @@ -462,65 +462,71 @@ streaming::Node adaptive_inner_join( OpID left_shuffle_tag, OpID right_shuffle_tag ) { - streaming::ShutdownAtExit c{left, right, ch_out}; + streaming::ShutdownAtExit c{left, right, left_meta, right_meta, ch_out}; co_await ctx->executor()->schedule(); (void)left_keys; (void)right_keys; - (void)allreduce_tag; (void)left_shuffle_tag; (void)right_shuffle_tag; + + auto consume_meta = [&]( + std::shared_ptr ch + ) -> coro::task> { + auto msg = co_await ch->receive(); + if (msg.empty()) { + co_return std::nullopt; + } + co_return msg.release(); + }; + auto [num_left_messages, num_right_messages] = streaming::coro_results( + co_await coro::when_all(consume_meta(left_meta), consume_meta(right_meta)) + ); // Assumption: the input channels carry only TableChunk messages. 
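+    // Assumption: each metadata channel delivers at most one message, holding a
+    // std::size_t estimate of the total number of chunks on that input.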
// Assumption: summing data_alloc_size across memory types is a good proxy for the // amount of data that will need to be materialized on device for compute. // Assumption: a small sample of chunks is representative of the whole table size. + // Assumption: metadata estimates reflect the total number of chunks per input. constexpr std::size_t inspect_messages = 2; std::vector left_buffer; std::vector right_buffer; - left_buffer.reserve(inspect_messages + 1); - right_buffer.reserve(inspect_messages + 1); + left_buffer.reserve(inspect_messages); + right_buffer.reserve(inspect_messages); auto inspect_channel = [&](std::shared_ptr ch, - std::vector& buffer) -> coro::task { + std::vector& buffer, + std::size_t estimated_num_messages) -> coro::task { std::size_t bytes = 0; for (std::size_t i = 0; i < inspect_messages; ++i) { auto msg = co_await ch->receive(); if (msg.empty()) { - // Preserve the termination marker for downstream processing. buffer.push_back(std::move(msg)); co_return bytes; } - RAPIDSMPF_EXPECTS( - msg.holds(), - "adaptive_inner_join expects TableChunk messages" - ); auto const& chunk = msg.get(); for (auto mem_type : MEMORY_TYPES) { bytes += chunk.data_alloc_size(mem_type); } buffer.push_back(std::move(msg)); } - co_return bytes; + co_return (bytes * estimated_num_messages) / inspect_messages; }; - auto left_sample_bytes = co_await inspect_channel(left, left_buffer); - auto right_sample_bytes = co_await inspect_channel(right, right_buffer); + auto left_local_bytes = co_await inspect_channel( + left, left_buffer, num_left_messages.value_or(inspect_messages) + ); + auto right_local_bytes = co_await inspect_channel( + right, right_buffer, num_right_messages.value_or(inspect_messages) + ); - std::size_t left_total_bytes = left_sample_bytes; - std::size_t right_total_bytes = right_sample_bytes; + std::size_t left_total_bytes = left_local_bytes; + std::size_t right_total_bytes = right_local_bytes; if (ctx->comm()->nranks() > 1) { std::tie(left_total_bytes, right_total_bytes) = co_await allgather_join_sizes( - ctx, allreduce_tag, left_sample_bytes, right_sample_bytes + ctx, allreduce_tag, left_local_bytes, right_local_bytes ); } - ctx->comm()->logger().print( - "Adaptive join sample sizes: left ", - left_sample_bytes, - " bytes, right ", - right_sample_bytes, - " bytes" - ); ctx->comm()->logger().print( "Adaptive join total sizes: left ", left_total_bytes, diff --git a/cpp/benchmarks/streaming/ndsh/join.hpp b/cpp/benchmarks/streaming/ndsh/join.hpp index b05299d65..1a46eb50a 100644 --- a/cpp/benchmarks/streaming/ndsh/join.hpp +++ b/cpp/benchmarks/streaming/ndsh/join.hpp @@ -150,6 +150,8 @@ enum class KeepKeys : bool { * @param ctx Streaming context. * @param left Channel of `TableChunk`s from the left table. * @param right Channel of `TableChunk`s from the right table. + * @param left_meta Channel carrying metadata for the left input (optional). + * @param right_meta Channel carrying metadata for the right input (optional). * @param ch_out Output channel of `TableChunk`s. * @param left_keys Column indices of the keys in the left table. * @param right_keys Column indices of the keys in the right table. 
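+ * @note If a metadata channel closes without delivering an estimate, the sampled
+ * chunk sizes are used without extrapolation.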
@@ -163,6 +165,8 @@ streaming::Node adaptive_inner_join( std::shared_ptr ctx, std::shared_ptr left, std::shared_ptr right, + std::shared_ptr left_meta, + std::shared_ptr right_meta, std::shared_ptr ch_out, std::vector left_keys, std::vector right_keys, From ca65c71e5dc9bdb28f00cb4d00bb7533539039da Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 21 Jan 2026 17:26:17 +0000 Subject: [PATCH 06/14] Implement a replay_channel utility node Having buffered some messages to inspect for size estimation we need to replay the full channel to interact with the broadcast and shuffle nodes. We do this by creating an output channel and sending in first the buffered messages and then consuming the remainder of the output channel. --- cpp/benchmarks/streaming/ndsh/join.cpp | 27 ++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/cpp/benchmarks/streaming/ndsh/join.cpp b/cpp/benchmarks/streaming/ndsh/join.cpp index 6749ebbfe..fafe34a7c 100644 --- a/cpp/benchmarks/streaming/ndsh/join.cpp +++ b/cpp/benchmarks/streaming/ndsh/join.cpp @@ -447,6 +447,33 @@ coro::task> allgather_join_sizes( } co_return {left_total_bytes, right_total_bytes}; } + +streaming::Node replay_channel( + std::shared_ptr ctx, + std::shared_ptr input, + std::shared_ptr output, + std::vector buffer +) { + streaming::ShutdownAtExit c{input, output}; + co_await ctx->executor()->schedule(); + for (auto&& msg : buffer) { + if (msg.empty()) { + co_await output->drain(ctx->executor()); + co_return; + } + if (!co_await output->send(std::move(msg))) { + co_return; + } + } + while (!output->is_shutdown()) { + auto msg = co_await input->receive(); + if (msg.empty()) { + break; + } + co_await output->send(std::move(msg)); + } + co_await output->drain(ctx->executor()); +} } // namespace streaming::Node adaptive_inner_join( From db4a2e287c0041ee0a89e0f179090929a99203c5 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 21 Jan 2026 17:38:13 +0000 Subject: [PATCH 07/14] Implement size-adaptive shuffle or broadcast join Using the new channel replay, and with statistics gathered from the buffered messages, dispatch to either a broadcast join or a shuffle join depending on the estimated sizes of the two tables. --- cpp/benchmarks/streaming/ndsh/join.cpp | 55 ++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 4 deletions(-) diff --git a/cpp/benchmarks/streaming/ndsh/join.cpp b/cpp/benchmarks/streaming/ndsh/join.cpp index fafe34a7c..b010b904b 100644 --- a/cpp/benchmarks/streaming/ndsh/join.cpp +++ b/cpp/benchmarks/streaming/ndsh/join.cpp @@ -491,10 +491,6 @@ streaming::Node adaptive_inner_join( ) { streaming::ShutdownAtExit c{left, right, left_meta, right_meta, ch_out}; co_await ctx->executor()->schedule(); - (void)left_keys; - (void)right_keys; - (void)left_shuffle_tag; - (void)right_shuffle_tag; auto consume_meta = [&]( std::shared_ptr ch @@ -513,6 +509,8 @@ streaming::Node adaptive_inner_join( // amount of data that will need to be materialized on device for compute. // Assumption: a small sample of chunks is representative of the whole table size. // Assumption: metadata estimates reflect the total number of chunks per input. 
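+    // Heuristic: broadcast only when the smaller side is small in absolute terms
+    // (broadcast_cap_bytes) and much smaller than the larger side
+    // (broadcast_ratio_threshold); otherwise shuffle both sides.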
+ constexpr std::size_t broadcast_cap_bytes = 8ULL * 1024ULL * 1024ULL * 1024ULL; + constexpr double broadcast_ratio_threshold = 0.10; constexpr std::size_t inspect_messages = 2; std::vector left_buffer; std::vector right_buffer; @@ -561,6 +559,55 @@ streaming::Node adaptive_inner_join( right_total_bytes, " bytes" ); + auto const min_bytes = std::min(left_total_bytes, right_total_bytes); + auto const max_bytes = std::max(left_total_bytes, right_total_bytes); + auto const broadcast_ratio = + (max_bytes == 0) ? 0.0 : static_cast(min_bytes) / max_bytes; + auto const use_broadcast = + min_bytes <= broadcast_cap_bytes && broadcast_ratio <= broadcast_ratio_threshold; + auto left_replay = ctx->create_channel(); + auto right_replay = ctx->create_channel(); + std::vector tasks; + tasks.push_back(replay_channel(ctx, left, left_replay, std::move(left_buffer))); + tasks.push_back(replay_channel(ctx, right, right_replay, std::move(right_buffer))); + if (use_broadcast) { + ctx->comm()->logger().print("Adaptive join strategy: broadcast"); + auto const broadcast_left = left_total_bytes <= right_total_bytes; + auto build = broadcast_left ? left_replay : right_replay; + auto probe = broadcast_left ? right_replay : left_replay; + auto build_keys = broadcast_left ? left_keys : right_keys; + auto probe_keys = broadcast_left ? right_keys : left_keys; + auto const broadcast_tag = broadcast_left ? left_shuffle_tag : right_shuffle_tag; + tasks.push_back(inner_join_broadcast( + ctx, + build, + probe, + ch_out, + std::move(build_keys), + std::move(probe_keys), + broadcast_tag + )); + } else { + ctx->comm()->logger().print("Adaptive join strategy: shuffle"); + auto const num_partitions = static_cast(ctx->comm()->nranks()); + auto left_shuffled = ctx->create_channel(); + auto right_shuffled = ctx->create_channel(); + tasks.push_back(shuffle( + ctx, left_replay, left_shuffled, left_keys, num_partitions, left_shuffle_tag + )); + tasks.push_back(shuffle( + ctx, + right_replay, + right_shuffled, + right_keys, + num_partitions, + right_shuffle_tag + )); + tasks.push_back(inner_join_shuffle( + ctx, left_shuffled, right_shuffled, ch_out, left_keys, right_keys + )); + } + streaming::coro_results(co_await coro::when_all(std::move(tasks))); co_await ch_out->drain(ctx->executor()); } } // namespace rapidsmpf::ndsh From d69c87e36d4d049431e7508a88d9a30d486dc7f6 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 21 Jan 2026 18:18:44 +0000 Subject: [PATCH 08/14] WIP: broadcast join accepts join side --- cpp/benchmarks/streaming/ndsh/join.cpp | 126 ++++++++++++++++++++++--- cpp/benchmarks/streaming/ndsh/join.hpp | 26 +++-- 2 files changed, 132 insertions(+), 20 deletions(-) diff --git a/cpp/benchmarks/streaming/ndsh/join.cpp b/cpp/benchmarks/streaming/ndsh/join.cpp index b010b904b..0d2feb253 100644 --- a/cpp/benchmarks/streaming/ndsh/join.cpp +++ b/cpp/benchmarks/streaming/ndsh/join.cpp @@ -227,23 +227,103 @@ streaming::Message inner_join_chunk( ); } +streaming::Message inner_join_chunk_probe_first( + std::shared_ptr ctx, + streaming::TableChunk&& probe_chunk, + std::uint64_t sequence, + cudf::hash_join& joiner, + cudf::table_view build_table, + std::vector build_keys, + std::vector probe_keys, + KeepKeys keep_keys, + rmm::cuda_stream_view build_stream, + CudaEvent* build_event +) { + CudaEvent event; + probe_chunk = to_device(ctx, std::move(probe_chunk)); + auto chunk_stream = probe_chunk.stream(); + build_event->stream_wait(chunk_stream); + auto probe_table = probe_chunk.table_view(); + auto probe_table_keys = 
probe_table.select(probe_keys); + auto [probe_match, build_match] = joiner.inner_join( + probe_table_keys, std::nullopt, chunk_stream, ctx->br()->device_mr() + ); + + cudf::column_view build_indices = + cudf::device_span(*build_match); + cudf::column_view probe_indices = + cudf::device_span(*probe_match); + + cudf::table_view probe_carrier; + if (keep_keys == KeepKeys::YES) { + probe_carrier = probe_table; + } else { + std::vector probe_to_keep; + std::ranges::copy_if( + std::ranges::iota_view(0, probe_table.num_columns()), + std::back_inserter(probe_to_keep), + [&](auto i) { return std::ranges::find(probe_keys, i) == probe_keys.end(); } + ); + probe_carrier = probe_table.select(probe_to_keep); + } + auto result_columns = cudf::gather( + probe_carrier, + probe_indices, + cudf::out_of_bounds_policy::DONT_CHECK, + chunk_stream, + ctx->br()->device_mr() + ) + ->release(); + + std::vector build_to_keep; + std::ranges::copy_if( + std::ranges::iota_view(0, build_table.num_columns()), + std::back_inserter(build_to_keep), + [&](auto i) { return std::ranges::find(build_keys, i) == build_keys.end(); } + ); + if (!build_to_keep.empty()) { + std::ranges::move( + cudf::gather( + build_table.select(build_to_keep), + build_indices, + cudf::out_of_bounds_policy::DONT_CHECK, + chunk_stream, + ctx->br()->device_mr() + ) + ->release(), + std::back_inserter(result_columns) + ); + } + cuda_stream_join(build_stream, chunk_stream, &event); + return streaming::to_message( + sequence, + std::make_unique( + std::make_unique(std::move(result_columns)), chunk_stream + ) + ); +} + streaming::Node inner_join_broadcast( std::shared_ptr ctx, - // We will always choose left as build table and do "broadcast" joins std::shared_ptr left, std::shared_ptr right, std::shared_ptr ch_out, std::vector left_on, std::vector right_on, OpID tag, - KeepKeys keep_keys + KeepKeys keep_keys, + BroadcastSide broadcast_side ) { streaming::ShutdownAtExit c{left, right, ch_out}; co_await ctx->executor()->schedule(); ctx->comm()->logger().print("Inner broadcast join ", static_cast(tag)); + auto build = broadcast_side == BroadcastSide::LEFT ? left : right; + auto probe = broadcast_side == BroadcastSide::LEFT ? right : left; + auto build_keys = broadcast_side == BroadcastSide::LEFT ? left_on : right_on; + auto probe_keys = broadcast_side == BroadcastSide::LEFT ? 
right_on : left_on; auto build_table = to_device( ctx, - (co_await broadcast(ctx, left, tag, streaming::AllGather::Ordered::NO)) + (co_await broadcast(ctx, build, tag, streaming::AllGather::Ordered::NO)) .release() ); ctx->comm()->logger().print( @@ -251,12 +331,34 @@ streaming::Node inner_join_broadcast( ); auto joiner = cudf::hash_join( - build_table.table_view().select(left_on), + build_table.table_view().select(build_keys), cudf::null_equality::UNEQUAL, build_table.stream() ); CudaEvent build_event; build_event.record(build_table.stream()); + if (broadcast_side == BroadcastSide::RIGHT) { + while (!ch_out->is_shutdown()) { + auto probe_msg = co_await probe->receive(); + if (probe_msg.empty()) { + break; + } + co_await ch_out->send(inner_join_chunk_probe_first( + ctx, + probe_msg.release(), + probe_msg.sequence_number(), + joiner, + build_table.table_view(), + build_keys, + probe_keys, + keep_keys, + build_table.stream(), + &build_event + )); + } + co_await ch_out->drain(ctx->executor()); + co_return; + } cudf::table_view build_carrier; if (keep_keys == KeepKeys::YES) { build_carrier = build_table.table_view(); @@ -265,22 +367,22 @@ streaming::Node inner_join_broadcast( std::ranges::copy_if( std::ranges::iota_view(0, build_table.table_view().num_columns()), std::back_inserter(to_keep), - [&](auto i) { return std::ranges::find(left_on, i) == left_on.end(); } + [&](auto i) { return std::ranges::find(build_keys, i) == build_keys.end(); } ); build_carrier = build_table.table_view().select(to_keep); } while (!ch_out->is_shutdown()) { - auto right_msg = co_await right->receive(); - if (right_msg.empty()) { + auto probe_msg = co_await probe->receive(); + if (probe_msg.empty()) { break; } co_await ch_out->send(inner_join_chunk( ctx, - right_msg.release(), - right_msg.sequence_number(), + probe_msg.release(), + probe_msg.sequence_number(), joiner, build_carrier, - right_on, + probe_keys, build_table.stream(), &build_event )); @@ -585,7 +687,9 @@ streaming::Node adaptive_inner_join( ch_out, std::move(build_keys), std::move(probe_keys), - broadcast_tag + broadcast_tag, + KeepKeys::YES, + broadcast_left ? BroadcastSide::LEFT : BroadcastSide::RIGHT )); } else { ctx->comm()->logger().print("Adaptive join strategy: shuffle"); diff --git a/cpp/benchmarks/streaming/ndsh/join.hpp b/cpp/benchmarks/streaming/ndsh/join.hpp index 1a46eb50a..09f9721e9 100644 --- a/cpp/benchmarks/streaming/ndsh/join.hpp +++ b/cpp/benchmarks/streaming/ndsh/join.hpp @@ -22,6 +22,12 @@ enum class KeepKeys : bool { YES, ///< Key columns do appear in the output }; +///< @brief Which input is the build side for a join. +enum class BroadcastSide : bool { + LEFT, ///< Broadcast the left input + RIGHT, ///< Broadcast the right input +}; + /** * @brief Broadcast the concatenation of all input messages to all ranks. * @@ -67,31 +73,33 @@ enum class KeepKeys : bool { /** * @brief Perform a streaming inner join between two tables. * - * @note This performs a broadcast join, broadcasting the table represented by the `left` - * channel to all ranks, and then streaming through the chunks of the `right` channel. + * @note This performs a broadcast join, broadcasting the chosen broadcast side to all + * ranks and then streaming through the chunks of the non-broadcasted side. Output + * ordering follows the logical left input: left keys (if kept), left non-keys, then right + * non-keys. * * @param ctx Streaming context. - * @param left Channel of `TableChunk`s used as the broadcasted build side. 
- * @param right Channel of `TableChunk`s joined in turn against the build side. + * @param left Channel of `TableChunk`s for the left input. + * @param right Channel of `TableChunk`s for the right input. * @param ch_out Output channel of `TableChunk`s. * @param left_on Column indices of the keys in the left table. * @param right_on Column indices of the keys in the right table. - * @param tag Disambiguating tag for the broadcast of the left table. - * @param keep_keys Does the result contain the key columns, or only "carrier" value - * columns + * @param tag Disambiguating tag for the broadcast. + * @param keep_keys Does the result contain the left key columns. + * @param broadcast_side Which input should be broadcast. * * @return Coroutine representing the completion of the join. */ [[nodiscard]] streaming::Node inner_join_broadcast( std::shared_ptr ctx, - // We will always choose left as build table and do "broadcast" joins std::shared_ptr left, std::shared_ptr right, std::shared_ptr ch_out, std::vector left_on, std::vector right_on, OpID tag, - KeepKeys keep_keys = KeepKeys::YES + KeepKeys keep_keys = KeepKeys::YES, + BroadcastSide broadcast_side = BroadcastSide::LEFT ); /** * @brief Perform a streaming inner join between two tables. From 45f15e26978b81984ee121b68f93968932e62827 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 22 Jan 2026 11:18:31 +0000 Subject: [PATCH 09/14] Refactor join chunk output assembly Use broadcast side to drive output column ordering and build/probe carrier selection via filtered views. Replace the duplicate join chunk helpers and update callers. --- cpp/benchmarks/streaming/ndsh/join.cpp | 237 ++++++++----------------- 1 file changed, 78 insertions(+), 159 deletions(-) diff --git a/cpp/benchmarks/streaming/ndsh/join.cpp b/cpp/benchmarks/streaming/ndsh/join.cpp index 0d2feb253..15640c3f0 100644 --- a/cpp/benchmarks/streaming/ndsh/join.cpp +++ b/cpp/benchmarks/streaming/ndsh/join.cpp @@ -150,96 +150,40 @@ streaming::Node broadcast( co_await ch_out->drain(ctx->executor()); } +namespace { + /** * @brief Join a table chunk against a build hash table returning a message of the result. * * @param ctx Streaming context - * @param right_chunk Chunk to join + * @param probe_chunk Chunk to join * @param sequence Sequence number of the output * @param joiner hash_join object, representing the build table. * @param build_carrier Columns from the build-side table to be included in the output. - * @param right_on Key column indiecs in `right_chunk`. + * The caller is responsible for selecting whether build-side key columns are present. + * @param probe_on Key column indices in `probe_chunk`. + * @param keep_keys Should the join keys be included in the output. + * @param broadcast_side Which table in the join was broadcasted (affects returned column + * ordering) * @param build_stream Stream the `joiner` will be deallocated on. * @param build_event Event recording the creation of the `joiner`. + * @param dealloc_event Event to use to stream-order deallocations. * * @return Message of `TableChunk` containing the result of the inner join. 
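+ * @note `dealloc_event` stream-orders work on `build_stream` after the probe
+ * chunk's stream so that the join indices and hash table are not freed while
+ * gathers are still in flight.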
*/ streaming::Message inner_join_chunk( - std::shared_ptr ctx, - streaming::TableChunk&& right_chunk, - std::uint64_t sequence, - cudf::hash_join& joiner, - cudf::table_view build_carrier, - std::vector right_on, - rmm::cuda_stream_view build_stream, - CudaEvent* build_event -) { - CudaEvent event; - right_chunk = to_device(ctx, std::move(right_chunk)); - auto chunk_stream = right_chunk.stream(); - build_event->stream_wait(chunk_stream); - auto probe_table = right_chunk.table_view(); - auto probe_keys = probe_table.select(right_on); - auto [probe_match, build_match] = - joiner.inner_join(probe_keys, std::nullopt, chunk_stream, ctx->br()->device_mr()); - - cudf::column_view build_indices = - cudf::device_span(*build_match); - cudf::column_view probe_indices = - cudf::device_span(*probe_match); - // build_carrier is valid on build_stream, but chunk_stream is - // waiting for build_stream work to be done, so running this on - // chunk_stream is fine. - auto result_columns = cudf::gather( - build_carrier, - build_indices, - cudf::out_of_bounds_policy::DONT_CHECK, - chunk_stream, - ctx->br()->device_mr() - ) - ->release(); - // drop key columns from probe table. - std::vector to_keep; - std::ranges::copy_if( - std::ranges::iota_view(0, probe_table.num_columns()), - std::back_inserter(to_keep), - [&](auto i) { return std::ranges::find(right_on, i) == right_on.end(); } - ); - std::ranges::move( - cudf::gather( - probe_table.select(to_keep), - probe_indices, - cudf::out_of_bounds_policy::DONT_CHECK, - chunk_stream, - ctx->br()->device_mr() - ) - ->release(), - std::back_inserter(result_columns) - ); - // Deallocation of the join indices will happen on build_stream, so add stream dep - // This also ensure deallocation of the hash_join object waits for completion. 
- cuda_stream_join(build_stream, chunk_stream, &event); - return streaming::to_message( - sequence, - std::make_unique( - std::make_unique(std::move(result_columns)), chunk_stream - ) - ); -} - -streaming::Message inner_join_chunk_probe_first( std::shared_ptr ctx, streaming::TableChunk&& probe_chunk, std::uint64_t sequence, cudf::hash_join& joiner, - cudf::table_view build_table, - std::vector build_keys, + cudf::table_view build_carrier, std::vector probe_keys, KeepKeys keep_keys, + BroadcastSide broadcast_side, rmm::cuda_stream_view build_stream, - CudaEvent* build_event + CudaEvent* build_event, + CudaEvent* dealloc_event ) { - CudaEvent event; probe_chunk = to_device(ctx, std::move(probe_chunk)); auto chunk_stream = probe_chunk.stream(); build_event->stream_wait(chunk_stream); @@ -249,52 +193,44 @@ streaming::Message inner_join_chunk_probe_first( probe_table_keys, std::nullopt, chunk_stream, ctx->br()->device_mr() ); - cudf::column_view build_indices = + cudf::column_view build_match_indices = cudf::device_span(*build_match); - cudf::column_view probe_indices = + cudf::column_view probe_match_indices = cudf::device_span(*probe_match); - cudf::table_view probe_carrier; - if (keep_keys == KeepKeys::YES) { - probe_carrier = probe_table; - } else { - std::vector probe_to_keep; - std::ranges::copy_if( - std::ranges::iota_view(0, probe_table.num_columns()), - std::back_inserter(probe_to_keep), - [&](auto i) { return std::ranges::find(probe_keys, i) == probe_keys.end(); } - ); - probe_carrier = probe_table.select(probe_to_keep); + std::vector> result_columns; + auto gather_columns = [&](cudf::table_view table, cudf::column_view indices) { + auto gathered = cudf::gather( + table, + indices, + cudf::out_of_bounds_policy::DONT_CHECK, + chunk_stream, + ctx->br()->device_mr() + ) + ->release(); + std::ranges::move(gathered, std::back_inserter(result_columns)); + }; + + auto const broadcast_left = broadcast_side == BroadcastSide::LEFT; + auto probe_carrier = probe_table; + if (keep_keys == KeepKeys::NO || broadcast_left) { + auto probe_indices = + std::views::iota(cudf::size_type{0}, probe_table.num_columns()) + | std::views::filter([&](auto i) { + return std::ranges::find(probe_keys, i) == probe_keys.end(); + }); + probe_carrier = probe_carrier.select(probe_indices.begin(), probe_indices.end()); } - auto result_columns = cudf::gather( - probe_carrier, - probe_indices, - cudf::out_of_bounds_policy::DONT_CHECK, - chunk_stream, - ctx->br()->device_mr() - ) - ->release(); - std::vector build_to_keep; - std::ranges::copy_if( - std::ranges::iota_view(0, build_table.num_columns()), - std::back_inserter(build_to_keep), - [&](auto i) { return std::ranges::find(build_keys, i) == build_keys.end(); } - ); - if (!build_to_keep.empty()) { - std::ranges::move( - cudf::gather( - build_table.select(build_to_keep), - build_indices, - cudf::out_of_bounds_policy::DONT_CHECK, - chunk_stream, - ctx->br()->device_mr() - ) - ->release(), - std::back_inserter(result_columns) - ); + if (broadcast_left) { + gather_columns(build_carrier, build_match_indices); + gather_columns(probe_carrier, probe_match_indices); + } else { + gather_columns(probe_carrier, probe_match_indices); + gather_columns(build_carrier, build_match_indices); } - cuda_stream_join(build_stream, chunk_stream, &event); + + cuda_stream_join(build_stream, chunk_stream, dealloc_event); return streaming::to_message( sequence, std::make_unique( @@ -302,6 +238,7 @@ streaming::Message inner_join_chunk_probe_first( ) ); } +} // namespace streaming::Node 
inner_join_broadcast( std::shared_ptr ctx, @@ -317,10 +254,11 @@ streaming::Node inner_join_broadcast( streaming::ShutdownAtExit c{left, right, ch_out}; co_await ctx->executor()->schedule(); ctx->comm()->logger().print("Inner broadcast join ", static_cast(tag)); - auto build = broadcast_side == BroadcastSide::LEFT ? left : right; - auto probe = broadcast_side == BroadcastSide::LEFT ? right : left; - auto build_keys = broadcast_side == BroadcastSide::LEFT ? left_on : right_on; - auto probe_keys = broadcast_side == BroadcastSide::LEFT ? right_on : left_on; + auto const broadcast_left = broadcast_side == BroadcastSide::LEFT; + auto build = broadcast_left ? left : right; + auto probe = broadcast_left ? right : left; + auto build_keys = broadcast_left ? left_on : right_on; + auto probe_keys = broadcast_left ? right_on : left_on; auto build_table = to_device( ctx, (co_await broadcast(ctx, build, tag, streaming::AllGather::Ordered::NO)) @@ -336,40 +274,16 @@ streaming::Node inner_join_broadcast( build_table.stream() ); CudaEvent build_event; + CudaEvent dealloc_event; build_event.record(build_table.stream()); - if (broadcast_side == BroadcastSide::RIGHT) { - while (!ch_out->is_shutdown()) { - auto probe_msg = co_await probe->receive(); - if (probe_msg.empty()) { - break; - } - co_await ch_out->send(inner_join_chunk_probe_first( - ctx, - probe_msg.release(), - probe_msg.sequence_number(), - joiner, - build_table.table_view(), - build_keys, - probe_keys, - keep_keys, - build_table.stream(), - &build_event - )); - } - co_await ch_out->drain(ctx->executor()); - co_return; - } - cudf::table_view build_carrier; - if (keep_keys == KeepKeys::YES) { - build_carrier = build_table.table_view(); - } else { - std::vector to_keep; - std::ranges::copy_if( - std::ranges::iota_view(0, build_table.table_view().num_columns()), - std::back_inserter(to_keep), - [&](auto i) { return std::ranges::find(build_keys, i) == build_keys.end(); } - ); - build_carrier = build_table.table_view().select(to_keep); + auto build_carrier = build_table.table_view(); + if (keep_keys == KeepKeys::NO || !broadcast_left) { + auto build_indices = + std::views::iota(cudf::size_type{0}, build_carrier.num_columns()) + | std::views::filter([&](auto i) { + return std::ranges::find(build_keys, i) == build_keys.end(); + }); + build_carrier = build_carrier.select(build_indices.begin(), build_indices.end()); } while (!ch_out->is_shutdown()) { auto probe_msg = co_await probe->receive(); @@ -383,8 +297,11 @@ streaming::Node inner_join_broadcast( joiner, build_carrier, probe_keys, + keep_keys, + broadcast_side, build_table.stream(), - &build_event + &build_event, + &dealloc_event )); } @@ -404,6 +321,7 @@ streaming::Node inner_join_shuffle( ctx->comm()->logger().print("Inner shuffle join"); co_await ctx->executor()->schedule(); CudaEvent build_event; + CudaEvent dealloc_event; while (!ch_out->is_shutdown()) { // Requirement: two shuffles kick out partitions in the same order auto left_msg = co_await left->receive(); @@ -427,17 +345,15 @@ streaming::Node inner_join_shuffle( build_stream ); build_event.record(build_stream); - cudf::table_view build_carrier; - if (keep_keys == KeepKeys::YES) { - build_carrier = build_chunk.table_view(); - } else { - std::vector to_keep; - std::ranges::copy_if( - std::ranges::iota_view(0, build_chunk.table_view().num_columns()), - std::back_inserter(to_keep), - [&](auto i) { return std::ranges::find(left_on, i) == left_on.end(); } - ); - build_carrier = build_chunk.table_view().select(to_keep); + auto build_carrier = 
build_chunk.table_view(); + if (keep_keys == KeepKeys::NO) { + auto build_indices = + std::views::iota(cudf::size_type{0}, build_carrier.num_columns()) + | std::views::filter([&](auto i) { + return std::ranges::find(left_on, i) == left_on.end(); + }); + build_carrier = + build_carrier.select(build_indices.begin(), build_indices.end()); } co_await ch_out->send(inner_join_chunk( ctx, @@ -446,8 +362,11 @@ streaming::Node inner_join_shuffle( joiner, build_carrier, right_on, + keep_keys, + BroadcastSide::LEFT, build_stream, - &build_event + &build_event, + &dealloc_event )); } co_await ch_out->drain(ctx->executor()); From 5f736069ea5d40001a83bff5ccb6d58afd909c1b Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 22 Jan 2026 12:18:56 +0000 Subject: [PATCH 10/14] Add send/receive metadata methods to Channel To enable adaptive algorithms in query operator nodes, we will often need some kind of metadata about the messages in a channel. Minimally we typically need the number of expected messages that are being sent. Rather than carrying a separate object around for metadata, give each Channel a metadata queue that can be pushed into by producers. We use a queue so that we can push in multiple messages without suspending on the produce side. --- .../rapidsmpf/streaming/core/channel.hpp | 25 +++++++++++++++++++ cpp/src/streaming/core/channel.cpp | 23 +++++++++++++++-- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/cpp/include/rapidsmpf/streaming/core/channel.hpp b/cpp/include/rapidsmpf/streaming/core/channel.hpp index 733126242..943087f1b 100644 --- a/cpp/include/rapidsmpf/streaming/core/channel.hpp +++ b/cpp/include/rapidsmpf/streaming/core/channel.hpp @@ -18,6 +18,7 @@ #include #include +#include #include namespace rapidsmpf::streaming { @@ -64,6 +65,29 @@ class Channel { */ coro::task receive(); + /** + * @brief Asynchronously send a metadata message into the channel. + * + * Suspends if the metadata queue is locked by another producer. + * + * @param msg The metadata message to send. + * @return A coroutine that evaluates to true if the msg was successfully sent or + * false if the channel was shut down. + * + * @throws std::logic_error If the message is empty. + */ + coro::task send_metadata(Message msg); + + /** + * @brief Asynchronously receive a metadata message from the channel. + * + * Suspends if the metadata queue is empty. + * + * @return A coroutine that evaluates to the message, which will be empty if the + * metadata queue is shut down. + */ + coro::task receive_metadata(); + /** * @brief Drains all pending messages from the channel and shuts it down. 
* @@ -103,6 +127,7 @@ class Channel { coro::ring_buffer rb_; std::shared_ptr sm_; + coro::queue metadata_; }; /** diff --git a/cpp/src/streaming/core/channel.cpp b/cpp/src/streaming/core/channel.cpp index dde57f138..245efc151 100644 --- a/cpp/src/streaming/core/channel.cpp +++ b/cpp/src/streaming/core/channel.cpp @@ -23,12 +23,31 @@ coro::task Channel::receive() { } } +coro::task Channel::send_metadata(Message msg) { + RAPIDSMPF_EXPECTS(!msg.empty(), "message cannot be empty"); + auto result = co_await metadata_.push(std::move(msg)); + co_return result == coro::queue_produce_result::produced; +} + +coro::task Channel::receive_metadata() { + auto msg = co_await metadata_.pop(); + if (msg.has_value()) { + co_return std::move(*msg); + } else { + co_return Message{}; + } +} + Node Channel::drain(std::shared_ptr executor) { - return rb_.shutdown_drain(executor->get()); + coro_results( + co_await coro::when_all( + rb_.shutdown_drain(executor->get()), metadata_.shutdown_drain(executor->get()) + ) + ); } Node Channel::shutdown() { - return rb_.shutdown(); + coro_results(co_await coro::when_all(metadata_.shutdown(), rb_.shutdown())); } bool Channel::empty() const noexcept { From 697d3108a7bf5cff2891ff38088a7a2b8899d09d Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 22 Jan 2026 12:33:11 +0000 Subject: [PATCH 11/14] Add ability to shutdown metadata independently of message channel --- .../rapidsmpf/streaming/core/channel.hpp | 20 +++++++++++++++++++ cpp/src/streaming/core/channel.cpp | 10 +++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/cpp/include/rapidsmpf/streaming/core/channel.hpp b/cpp/include/rapidsmpf/streaming/core/channel.hpp index 943087f1b..3ca99a052 100644 --- a/cpp/include/rapidsmpf/streaming/core/channel.hpp +++ b/cpp/include/rapidsmpf/streaming/core/channel.hpp @@ -88,6 +88,17 @@ class Channel { */ coro::task receive_metadata(); + /** + * @brief Drains all pending metadata messages from the channel and shuts down the + * metadata channel. + * + * This is intended to ensure all remaining metadata messages are processed. + * + * @param executor The thread pool used to process remaining messages. + * @return A coroutine representing the completion of the metadata shutdown drain. + */ + [[nodiscard]] Node drain_metadata(std::shared_ptr executor); + /** * @brief Drains all pending messages from the channel and shuts it down. * @@ -107,6 +118,15 @@ class Channel { */ Node shutdown(); + /** + * @brief Immediately shuts down the metadata queue. + * + * Any pending or future metadata send/receive operations will complete with failure. + * + * @return A coroutine representing the completion of the shutdown. + */ + Node shutdown_metadata(); + /** * @brief Check whether the channel is empty. 
* diff --git a/cpp/src/streaming/core/channel.cpp b/cpp/src/streaming/core/channel.cpp index 245efc151..1fdf3e02a 100644 --- a/cpp/src/streaming/core/channel.cpp +++ b/cpp/src/streaming/core/channel.cpp @@ -29,6 +29,10 @@ coro::task Channel::send_metadata(Message msg) { co_return result == coro::queue_produce_result::produced; } +Node Channel::drain_metadata(std::shared_ptr executor) { + return metadata_.shutdown_drain(executor->get()); +} + coro::task Channel::receive_metadata() { auto msg = co_await metadata_.pop(); if (msg.has_value()) { @@ -41,7 +45,7 @@ coro::task Channel::receive_metadata() { Node Channel::drain(std::shared_ptr executor) { coro_results( co_await coro::when_all( - rb_.shutdown_drain(executor->get()), metadata_.shutdown_drain(executor->get()) + rb_.shutdown_drain(executor->get()), drain_metadata(executor) ) ); } @@ -50,6 +54,10 @@ Node Channel::shutdown() { coro_results(co_await coro::when_all(metadata_.shutdown(), rb_.shutdown())); } +Node Channel::shutdown_metadata() { + return metadata_.shutdown(); +} + bool Channel::empty() const noexcept { return rb_.empty(); } From 7379dd3498823af900dae383cdf77db3271d0239 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 22 Jan 2026 12:34:29 +0000 Subject: [PATCH 12/14] Add nodiscard to Channel coroutines --- cpp/include/rapidsmpf/streaming/core/channel.hpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cpp/include/rapidsmpf/streaming/core/channel.hpp b/cpp/include/rapidsmpf/streaming/core/channel.hpp index 3ca99a052..959a154a7 100644 --- a/cpp/include/rapidsmpf/streaming/core/channel.hpp +++ b/cpp/include/rapidsmpf/streaming/core/channel.hpp @@ -51,7 +51,7 @@ class Channel { * * @throws std::logic_error If the message is empty. */ - coro::task send(Message msg); + [[nodiscard]] coro::task send(Message msg); /** * @brief Asynchronously receive a message from the channel. @@ -63,7 +63,7 @@ class Channel { * * @throws std::logic_error If the received message is empty. */ - coro::task receive(); + [[nodiscard]] coro::task receive(); /** * @brief Asynchronously send a metadata message into the channel. @@ -76,7 +76,7 @@ class Channel { * * @throws std::logic_error If the message is empty. */ - coro::task send_metadata(Message msg); + [[nodiscard]] coro::task send_metadata(Message msg); /** * @brief Asynchronously receive a metadata message from the channel. @@ -86,7 +86,7 @@ class Channel { * @return A coroutine that evaluates to the message, which will be empty if the * metadata queue is shut down. */ - coro::task receive_metadata(); + [[nodiscard]] coro::task receive_metadata(); /** * @brief Drains all pending metadata messages from the channel and shuts down the @@ -107,7 +107,7 @@ class Channel { * @param executor The thread pool used to process remaining messages. * @return A coroutine representing the completion of the shutdown drain. */ - Node drain(std::shared_ptr executor); + [[nodiscard]] Node drain(std::shared_ptr executor); /** * @brief Immediately shuts down the channel. @@ -116,7 +116,7 @@ class Channel { * * @return A coroutine representing the completion of the shutdown. */ - Node shutdown(); + [[nodiscard]] Node shutdown(); /** * @brief Immediately shuts down the metadata queue. @@ -125,7 +125,7 @@ class Channel { * * @return A coroutine representing the completion of the shutdown. */ - Node shutdown_metadata(); + [[nodiscard]] Node shutdown_metadata(); /** * @brief Check whether the channel is empty. 
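As a usage sketch of the metadata queue added in the last three patches: a producer can announce how many data messages it intends to send, and a consumer can read that hint before entering its data loop. The helper names below are ours, the `Context` spelling and the payload/accessor shape of `Message` follow the tests in the next patch, and this is an illustration rather than code from the series.

    #include <cstddef>
    #include <memory>
    #include <vector>

    #include <rapidsmpf/streaming/core/channel.hpp>

    using namespace rapidsmpf::streaming;

    // Sketch: announce the expected message count, then stream the data.
    Node counting_producer(
        std::shared_ptr<Context> ctx,
        std::shared_ptr<Channel> ch,
        std::vector<Message> msgs
    ) {
        ShutdownAtExit c{ch};
        co_await ctx->executor()->schedule();
        // The queue-backed metadata path lets us push this without suspending
        // behind consumers of the data ring buffer.
        co_await ch->send_metadata(Message{
            0, std::make_unique<std::size_t>(msgs.size()), rapidsmpf::ContentDescription{}
        });
        co_await ch->drain_metadata(ctx->executor());
        for (auto& msg : msgs) {
            co_await ch->send(std::move(msg));
        }
        co_await ch->drain(ctx->executor());
    }

    // Sketch: read the hint (if any), then consume data as usual.
    Node counting_consumer(std::shared_ptr<Context> ctx, std::shared_ptr<Channel> ch) {
        ShutdownAtExit c{ch};
        co_await ctx->executor()->schedule();
        std::size_t expected = 0;  // zero: producer sent no hint
        if (auto meta = co_await ch->receive_metadata(); !meta.empty()) {
            expected = meta.get<std::size_t>();  // accessor assumed, as in the tests
        }
        // ... use `expected` to size buffers or pick a join strategy, then loop
        // on ch->receive() until an empty message arrives.
    }
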
From e500302c491ade3ebccdc26a7a09950913519425 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 22 Jan 2026 12:54:02 +0000 Subject: [PATCH 13/14] Add basic channel tests --- cpp/tests/CMakeLists.txt | 1 + cpp/tests/streaming/test_channel.cpp | 98 ++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+) create mode 100644 cpp/tests/streaming/test_channel.cpp diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index d695ba3a2..388c81476 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -98,6 +98,7 @@ if(RAPIDSMPF_HAVE_STREAMING) target_sources( test_sources PRIVATE streaming/test_allgather.cpp + streaming/test_channel.cpp streaming/test_error_handling.cpp streaming/test_fanout.cpp streaming/test_leaf_node.cpp diff --git a/cpp/tests/streaming/test_channel.cpp b/cpp/tests/streaming/test_channel.cpp new file mode 100644 index 000000000..27e94b99d --- /dev/null +++ b/cpp/tests/streaming/test_channel.cpp @@ -0,0 +1,98 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include + +#include +#include +#include +#include + +#include "base_streaming_fixture.hpp" + +using namespace rapidsmpf::streaming; + +namespace { +std::vector make_int_messages(std::size_t n) { + std::vector messages; + messages.reserve(n); + for (std::size_t i = 0; i < n; ++i) { + messages.emplace_back( + i, std::make_unique(i), rapidsmpf::ContentDescription{} + ); + } + return messages; +} +} // namespace + +using BaseStreamingChannel = BaseStreamingFixture; + +TEST_F(BaseStreamingChannel, DataRoundTripWithoutMetadata) { + auto ch = ctx->create_channel(); + std::vector outputs; + std::vector nodes; + static constexpr std::size_t num_messages = 4; + nodes.emplace_back(node::push_to_channel(ctx, ch, make_int_messages(num_messages))); + nodes.emplace_back(node::pull_from_channel(ctx, ch, outputs)); + run_streaming_pipeline(std::move(nodes)); + + ASSERT_EQ(outputs.size(), num_messages); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(outputs[static_cast(i)].release(), i); + } +} + +TEST_F(BaseStreamingChannel, MetadataSendReceiveAndShutdown) { + auto ch = ctx->create_channel(); + std::vector metadata_outputs; + std::vector data_outputs; + std::vector nodes; + + auto producer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + + co_await ch->send_metadata(Message{0, std::make_unique(10), {}}); + co_await ch->send_metadata(Message{1, std::make_unique(20), {}}); + co_await ch->shutdown_metadata(); + + co_await ch->send(Message{0, std::make_unique(1), {}}); + co_await ch->send(Message{1, std::make_unique(2), {}}); + co_await ch->drain(ctx->executor()); + }; + + auto consumer = [this, ch, &metadata_outputs, &data_outputs]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + + while (true) { + auto msg = co_await ch->receive_metadata(); + if (msg.empty()) { + break; + } + metadata_outputs.push_back(std::move(msg)); + } + + while (true) { + auto msg = co_await ch->receive(); + if (msg.empty()) { + break; + } + data_outputs.push_back(std::move(msg)); + } + }; + + nodes.emplace_back(producer()); + nodes.emplace_back(consumer()); + run_streaming_pipeline(std::move(nodes)); + + ASSERT_EQ(metadata_outputs.size(), 2U); + EXPECT_EQ(metadata_outputs[0].get(), 10); + EXPECT_EQ(metadata_outputs[1].get(), 20); + + ASSERT_EQ(data_outputs.size(), 2U); + EXPECT_EQ(data_outputs[0].get(), 1); + EXPECT_EQ(data_outputs[1].get(), 2); +} 
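An aside on the filtered-view idiom the join refactor above settled on: selecting the non-key "carrier" columns is an iota-filter pipe into the iterator-pair overload of cudf::table_view::select, exactly as used in the patched join code. A minimal standalone sketch (the helper name is ours):

    #include <algorithm>
    #include <ranges>
    #include <vector>

    #include <cudf/table/table_view.hpp>

    // View of every column of `table` whose index is not in `keys`, in order.
    cudf::table_view non_key_columns(
        cudf::table_view table, std::vector<cudf::size_type> const& keys
    ) {
        auto indices = std::views::iota(cudf::size_type{0}, table.num_columns())
                     | std::views::filter([&](auto i) {
                           return std::ranges::find(keys, i) == keys.end();
                       });
        return table.select(indices.begin(), indices.end());
    }

Compared with the copy_if-into-a-temporary-vector form it replaces, this keeps the same column order while avoiding the intermediate index vector.
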
From 76bdb97d282ded99040bd75515c5d0eb252a91ee Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 22 Jan 2026 15:34:59 +0000 Subject: [PATCH 14/14] Add more metadata tests of channel --- cpp/tests/streaming/test_channel.cpp | 235 ++++++++++++++++++++++++++- 1 file changed, 228 insertions(+), 7 deletions(-) diff --git a/cpp/tests/streaming/test_channel.cpp b/cpp/tests/streaming/test_channel.cpp index 27e94b99d..54a89fa1d 100644 --- a/cpp/tests/streaming/test_channel.cpp +++ b/cpp/tests/streaming/test_channel.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -54,13 +55,17 @@ TEST_F(BaseStreamingChannel, MetadataSendReceiveAndShutdown) { ShutdownAtExit c{ch}; co_await ctx->executor()->schedule(); - co_await ch->send_metadata(Message{0, std::make_unique(10), {}}); - co_await ch->send_metadata(Message{1, std::make_unique(20), {}}); - co_await ch->shutdown_metadata(); - - co_await ch->send(Message{0, std::make_unique(1), {}}); - co_await ch->send(Message{1, std::make_unique(2), {}}); - co_await ch->drain(ctx->executor()); + auto meta_task = [&]() -> Node { + co_await ch->send_metadata(Message{0, std::make_unique(10), {}}); + co_await ch->send_metadata(Message{1, std::make_unique(20), {}}); + co_await ch->drain_metadata(ctx->executor()); + }; + auto send_task = [&]() -> Node { + co_await ch->send(Message{0, std::make_unique(1), {}}); + co_await ch->send(Message{1, std::make_unique(2), {}}); + co_await ch->drain(ctx->executor()); + }; + coro_results(co_await coro::when_all(meta_task(), send_task())); }; auto consumer = [this, ch, &metadata_outputs, &data_outputs]() -> Node { @@ -96,3 +101,219 @@ TEST_F(BaseStreamingChannel, MetadataSendReceiveAndShutdown) { EXPECT_EQ(data_outputs[0].get(), 1); EXPECT_EQ(data_outputs[1].get(), 2); } + +TEST_F(BaseStreamingChannel, DataOnlyWithMetadataShutdown) { + auto ch = ctx->create_channel(); + std::vector data_outputs; + std::vector metadata_outputs; + std::vector nodes; + + auto producer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + + co_await ch->shutdown_metadata(); + co_await ch->send( + Message{0, std::make_unique(10), rapidsmpf::ContentDescription{}} + ); + co_await ch->send( + Message{1, std::make_unique(20), rapidsmpf::ContentDescription{}} + ); + co_await ch->drain(ctx->executor()); + }; + + auto consumer = [this, ch, &metadata_outputs, &data_outputs]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + + while (true) { + auto msg = co_await ch->receive_metadata(); + if (msg.empty()) { + break; + } + metadata_outputs.push_back(std::move(msg)); + } + + while (true) { + auto msg = co_await ch->receive(); + if (msg.empty()) { + break; + } + data_outputs.push_back(std::move(msg)); + } + }; + + nodes.emplace_back(producer()); + nodes.emplace_back(consumer()); + run_streaming_pipeline(std::move(nodes)); + + EXPECT_TRUE(metadata_outputs.empty()); + ASSERT_EQ(data_outputs.size(), 2U); + EXPECT_EQ(data_outputs[0].get(), 10); + EXPECT_EQ(data_outputs[1].get(), 20); +} + +TEST_F(BaseStreamingChannel, MetadataOnlyWithDataShutdown) { + auto ch = ctx->create_channel(); + std::vector metadata_outputs; + std::vector data_outputs; + std::vector nodes; + + auto producer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + + co_await ch->send_metadata( + Message{0, std::make_unique(10), rapidsmpf::ContentDescription{}} + ); + co_await ch->send_metadata( + Message{1, std::make_unique(20), rapidsmpf::ContentDescription{}} + ); + 
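+        // Note: drain() drains and shuts down both the metadata queue and the
+        // data ring buffer, so the consumer's receive_metadata loop below
+        // terminates with an empty message rather than hanging.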
co_await ch->drain(ctx->executor()); + }; + + auto consumer = [this, ch, &metadata_outputs, &data_outputs]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + + while (true) { + auto msg = co_await ch->receive_metadata(); + if (msg.empty()) { + break; + } + metadata_outputs.push_back(std::move(msg)); + } + + while (true) { + auto msg = co_await ch->receive(); + if (msg.empty()) { + break; + } + data_outputs.push_back(std::move(msg)); + } + }; + + nodes.emplace_back(producer()); + nodes.emplace_back(consumer()); + run_streaming_pipeline(std::move(nodes)); + + ASSERT_EQ(metadata_outputs.size(), 2U); + EXPECT_EQ(metadata_outputs[0].get(), 10); + EXPECT_EQ(metadata_outputs[1].get(), 20); + EXPECT_TRUE(data_outputs.empty()); +} + +TEST_F(BaseStreamingChannel, ConsumerIgnoresMetadata) { + auto ch = ctx->create_channel(); + std::vector data_outputs; + std::vector nodes; + + auto producer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + + co_await ch->send_metadata( + Message{0, std::make_unique(10), rapidsmpf::ContentDescription{}} + ); + co_await ch->send_metadata( + Message{0, std::make_unique(20), rapidsmpf::ContentDescription{}} + ); + co_await ch->send( + Message{1, std::make_unique(30), rapidsmpf::ContentDescription{}} + ); + co_await ch->drain(ctx->executor()); + }; + + auto consumer = [this, ch, &data_outputs]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + while (true) { + auto msg = co_await ch->receive(); + if (msg.empty()) { + break; + } + data_outputs.push_back(std::move(msg)); + } + }; + + nodes.emplace_back(producer()); + nodes.emplace_back(consumer()); + run_streaming_pipeline(std::move(nodes)); + + EXPECT_EQ(data_outputs.size(), 1U); + EXPECT_EQ(data_outputs[0].get(), 30); +} + +TEST_F(BaseStreamingChannel, ProducerThrowsWithMetadata) { + auto ch = ctx->create_channel(); + std::vector nodes; + + auto producer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + co_await ch->send_metadata( + Message{0, std::make_unique(31), rapidsmpf::ContentDescription{}} + ); + throw std::runtime_error("producer failed"); + }; + + auto consumer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + while (true) { + auto msg = co_await ch->receive_metadata(); + if (msg.empty()) { + break; + } + } + }; + + nodes.emplace_back(producer()); + nodes.emplace_back(consumer()); + EXPECT_THROW(run_streaming_pipeline(std::move(nodes)), std::runtime_error); +} + +TEST_F(BaseStreamingChannel, ConsumerThrowsWithMetadata) { + auto ch = ctx->create_channel(); + std::vector nodes; + + auto producer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + co_await ch->send_metadata( + Message{0, std::make_unique(10), rapidsmpf::ContentDescription{}} + ); + co_await ch->drain(ctx->executor()); + }; + + auto consumer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + throw std::runtime_error("consumer failed"); + }; + + nodes.emplace_back(producer()); + nodes.emplace_back(consumer()); + EXPECT_THROW(run_streaming_pipeline(std::move(nodes)), std::runtime_error); +} + +TEST_F(BaseStreamingChannel, ProducerAndConsumerThrow) { + auto ch = ctx->create_channel(); + std::vector nodes; + + auto producer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + throw std::runtime_error("producer failed"); + }; + + auto consumer = [this, ch]() -> Node 
{ + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + throw std::runtime_error("consumer failed"); + }; + + nodes.emplace_back(producer()); + nodes.emplace_back(consumer()); + EXPECT_THROW(run_streaming_pipeline(std::move(nodes)), std::runtime_error); +}
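
Pulling the series together, the decision inside adaptive_inner_join ends up with roughly this shape once global input sizes are in hand. This is a sketch of where the pieces meet, not code from the series: `left_bytes`/`right_bytes` and `broadcast_threshold` are assumed stand-ins for the size estimates and the tuning knob, and the broadcast log line is assumed by symmetry with the shuffle one; `inner_join_broadcast`, `KeepKeys`, and `BroadcastSide` are the interfaces patched above.

    // Fragment from inside adaptive_inner_join (std::min from <algorithm>).
    // Assumed: left_bytes / right_bytes hold global size estimates for the two
    // inputs (e.g. sampled chunk sizes combined across ranks).
    constexpr std::size_t broadcast_threshold = std::size_t{128} << 20;  // assumed knob
    bool const broadcast_left = left_bytes <= right_bytes;
    if (std::min(left_bytes, right_bytes) <= broadcast_threshold) {
        ctx->comm()->logger().print("Adaptive join strategy: broadcast");
        // Broadcast the smaller input as the build side; output column order
        // still follows the logical left input.
        co_await inner_join_broadcast(
            ctx,
            left,
            right,
            ch_out,
            std::move(left_keys),
            std::move(right_keys),
            broadcast_tag,
            KeepKeys::YES,
            broadcast_left ? BroadcastSide::LEFT : BroadcastSide::RIGHT
        );
    } else {
        ctx->comm()->logger().print("Adaptive join strategy: shuffle");
        // Shuffle both inputs on their keys (left_shuffle_tag /
        // right_shuffle_tag) and join matching partitions pairwise with
        // inner_join_shuffle.
    }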