diff --git a/cpp/benchmarks/streaming/ndsh/CMakeLists.txt b/cpp/benchmarks/streaming/ndsh/CMakeLists.txt index 5b9f2dc6e..cc5b453fc 100644 --- a/cpp/benchmarks/streaming/ndsh/CMakeLists.txt +++ b/cpp/benchmarks/streaming/ndsh/CMakeLists.txt @@ -41,7 +41,7 @@ target_link_libraries( $ maybe_asan ) -set(RAPIDSMPFNDSH_QUERIES q01 q03 q09 q21 bench_read) +set(RAPIDSMPFNDSH_QUERIES q01 q03 q09 q21 bench_read adaptive_join) foreach(query IN ITEMS ${RAPIDSMPFNDSH_QUERIES}) add_executable(${query} "${query}.cpp") diff --git a/cpp/benchmarks/streaming/ndsh/adaptive_join.cpp b/cpp/benchmarks/streaming/ndsh/adaptive_join.cpp new file mode 100644 index 000000000..eb4c24be8 --- /dev/null +++ b/cpp/benchmarks/streaming/ndsh/adaptive_join.cpp @@ -0,0 +1,553 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "join.hpp" +#include "utils.hpp" + +#include + +namespace { + +rapidsmpf::streaming::Node read_parquet( + std::shared_ptr ctx, + std::shared_ptr ch_out, + std::size_t num_producers, + cudf::size_type num_rows_per_chunk, + std::optional> columns, + std::string const& input_directory, + std::string const& input_file +) { + auto files = rapidsmpf::ndsh::detail::list_parquet_files( + rapidsmpf::ndsh::detail::get_table_path(input_directory, input_file) + ); + auto options = + cudf::io::parquet_reader_options::builder(cudf::io::source_info(files)).build(); + if (columns.has_value()) { + options.set_columns(*columns); + } + return rapidsmpf::streaming::node::read_parquet( + ctx, ch_out, num_producers, options, num_rows_per_chunk + ); +} + +[[maybe_unused]] std::size_t estimate_read_parquet_messages( + std::shared_ptr ctx, + std::string const& input_directory, + std::string const& input_file, + cudf::size_type num_rows_per_chunk +) { + auto files = rapidsmpf::ndsh::detail::list_parquet_files( + rapidsmpf::ndsh::detail::get_table_path(input_directory, input_file) + ); + if (files.empty()) { + return 0; + } + + // Assumption: this file-to-rank mapping matches read_parquet in the streaming node. + auto const rank = static_cast(ctx->comm()->rank()); + auto const size = static_cast(ctx->comm()->nranks()); + auto const base = files.size() / size; + auto const extra = files.size() % size; + auto const files_per_rank = base + (rank < extra ? 1 : 0); + auto const file_offset = rank * base + std::min(rank, extra); + if (files_per_rank == 0) { + return 0; + } + + std::size_t total_rows = 0; + for (std::size_t i = 0; i < files_per_rank; ++i) { + auto const& file = files[file_offset + i]; + total_rows += static_cast( + cudf::io::read_parquet_metadata(cudf::io::source_info(file)).num_rows() + ); + } + if (total_rows == 0 || num_rows_per_chunk <= 0) { + return 0; + } + + // Assumption: chunk sizes are close to num_rows_per_chunk and filters are absent. 
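+    // Worked example (hypothetical numbers): total_rows = 10'000 with
+    // num_rows_per_chunk = 3'000 gives (10'000 + 2'999) / 3'000 = 4 estimated messages.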
+    auto const chunk_rows = static_cast<std::size_t>(num_rows_per_chunk);
+    return (total_rows + chunk_rows - 1) / chunk_rows;
+}
+
+[[maybe_unused]] rapidsmpf::streaming::Node advertise_message_count(
+    std::shared_ptr<rapidsmpf::streaming::Context> ctx,
+    std::shared_ptr<rapidsmpf::streaming::Channel> ch_meta,
+    std::size_t estimate
+) {
+    rapidsmpf::streaming::ShutdownAtExit c{ch_meta};
+    co_await ctx->executor()->schedule();
+    auto payload = std::make_unique<std::size_t>(estimate);
+    co_await ch_meta->send(rapidsmpf::streaming::Message{0, std::move(payload), {}, {}});
+    co_await ch_meta->drain(ctx->executor());
+    ctx->comm()->logger().print("Exiting advertise_message_count");
+}
+
+[[maybe_unused]] rapidsmpf::streaming::Node consume_channel_parallel(
+    std::shared_ptr<rapidsmpf::streaming::Context> ctx,
+    std::shared_ptr<rapidsmpf::streaming::Channel> ch_in,
+    std::size_t /* num_consumers (currently unused; a single consumer loop) */
+) {
+    rapidsmpf::streaming::ShutdownAtExit c{ch_in};
+    std::size_t estimated_total_bytes{0};
+    co_await ctx->executor()->schedule();
+    while (true) {
+        auto msg = co_await ch_in->receive();
+        if (msg.empty()) {
+            break;
+        }
+        if (msg.holds<rapidsmpf::streaming::TableChunk>()) {
+            auto chunk = rapidsmpf::ndsh::to_device(
+                ctx, msg.release<rapidsmpf::streaming::TableChunk>()
+            );
+            ctx->comm()->logger().print(
+                "Consumed chunk ",
+                msg.sequence_number(),
+                " with ",
+                chunk.table_view().num_rows(),
+                " rows and ",
+                chunk.table_view().num_columns(),
+                " columns"
+            );
+            estimated_total_bytes += chunk.data_alloc_size(rapidsmpf::MemoryType::DEVICE);
+        }
+    }
+    ctx->comm()->logger().print(
+        "Table was around ", rmm::detail::format_bytes(estimated_total_bytes)
+    );
+}
+
+/// @brief Configuration options for the benchmark.
+struct ProgramOptions {
+    int num_streaming_threads{1};  ///< Number of streaming threads to use
+    int num_iterations{2};  ///< Number of iterations of query to run
+    int num_streams{16};  ///< Number of streams in stream pool
+    rapidsmpf::ndsh::CommType comm_type{
+        rapidsmpf::ndsh::CommType::UCXX
+    };  ///< Type of communicator to create
+    cudf::size_type num_rows_per_chunk{
+        100'000'000
+    };  ///< Number of rows to produce per chunk read
+    std::size_t num_producers{
+        1
+    };  ///< Number of simultaneous read_parquet chunk producers.
+    std::size_t num_consumers{1};  ///< Number of simultaneous chunk consumers.
+    std::string input_directory;  ///< Directory containing input files.
+    std::string left_input_file;  ///< Basename of left input file to read.
+    std::string right_input_file;  ///< Basename of right input file to read.
+    std::optional<std::vector<std::string>> left_columns{
+        std::nullopt
+    };  ///< Columns to read (left input).
+    std::optional<std::vector<std::string>> right_columns{
+        std::nullopt
+    };  ///< Columns to read (right input).
+}; + +ProgramOptions parse_arguments(int argc, char** argv) { + ProgramOptions options; + + static constexpr std:: + array(rapidsmpf::ndsh::CommType::MAX)> + comm_names{"single", "mpi", "ucxx"}; + + auto print_usage = [&argv, &options]() { + std::cerr + << "Usage: " << argv[0] << " [options]\n" + << "Options:\n" + << " --num-streaming-threads Number of streaming threads (default: " + << options.num_streaming_threads << ")\n" + << " --num-iterations Number of iterations (default: " + << options.num_iterations << ")\n" + << " --num-streams Number of streams in stream pool " + "(default: " + << options.num_streams << ")\n" + << " --num-rows-per-chunk Number of rows per chunk (default: " + << options.num_rows_per_chunk << ")\n" + << " --num-producers Number of concurrent read_parquet " + "producers (default: " + << options.num_producers << ")\n" + << " --num-consumers Number of concurrent consumers (default: " + << options.num_consumers << ")\n" + << " --comm-type Communicator type: single, mpi, ucxx " + "(default: " + << comm_names[static_cast(options.comm_type)] << ")\n" + << " --input-directory Input directory path (required)\n" + << " --left-input-file Left input file basename relative to " + "input " + "directory (required)\n" + << " --right-input-file Right input file basename relative to " + "input directory (required)\n" + << " --left-columns Comma-separated column names to read " + "(optional, default all columns)\n" + << " --right-columns Comma-separated column names to read " + "(optional, default all columns)\n" + << " --help Show this help message\n"; + }; + + // NOLINTBEGIN(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,modernize-use-designated-initializers) + static struct option long_options[] = { + {"num-streaming-threads", required_argument, nullptr, 1}, + {"num-rows-per-chunk", required_argument, nullptr, 2}, + {"num-producers", required_argument, nullptr, 3}, + {"num-consumers", required_argument, nullptr, 4}, + {"input-directory", required_argument, nullptr, 5}, + {"left-input-file", required_argument, nullptr, 6}, + {"right-input-file", required_argument, nullptr, 12}, + {"help", no_argument, nullptr, 7}, + {"num-iterations", required_argument, nullptr, 8}, + {"num-streams", required_argument, nullptr, 9}, + {"comm-type", required_argument, nullptr, 10}, + {"left-columns", required_argument, nullptr, 11}, + {"right-columns", required_argument, nullptr, 13}, + {nullptr, 0, nullptr, 0} + }; + // NOLINTEND(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,modernize-use-designated-initializers) + + int opt; + int option_index = 0; + + bool saw_input_directory = false; + bool saw_left_input_file = false; + bool saw_right_input_file = false; + + auto parse_i64 = [](char const* s, char const* opt_name) -> long long { + if (s == nullptr || *s == '\0') { + std::cerr << "Error: " << opt_name << " requires a value\n"; + std::exit(1); + } + errno = 0; + char* end = nullptr; + auto const v = std::strtoll(s, &end, 10); + if (errno != 0 || end == s || *end != '\0') { + std::cerr << "Error: invalid integer for " << opt_name << ": '" << s << "'\n"; + std::exit(1); + } + return v; + }; + + auto parse_u64 = [](char const* s, char const* opt_name) -> unsigned long long { + if (s == nullptr || *s == '\0') { + std::cerr << "Error: " << opt_name << " requires a value\n"; + std::exit(1); + } + errno = 0; + char* end = nullptr; + auto const v = std::strtoull(s, &end, 10); + if (errno != 0 || end == s || *end != '\0') { + std::cerr << "Error: invalid non-negative integer for " 
<< opt_name << ": '" + << s << "'\n"; + std::exit(1); + } + return v; + }; + + auto require_positive_i32 = [&](char const* s, char const* opt_name) -> int { + auto const v = parse_i64(s, opt_name); + if (v <= 0 || v > std::numeric_limits::max()) { + std::cerr << "Error: " << opt_name << " must be in [1, " + << std::numeric_limits::max() << "], got '" << s << "'\n"; + std::exit(1); + } + return static_cast(v); + }; + + auto require_positive_size_t = [&](char const* s, + char const* opt_name) -> std::size_t { + auto const v = parse_u64(s, opt_name); + if (v == 0 || v > std::numeric_limits::max()) { + std::cerr << "Error: " << opt_name << " must be in [1, " + << std::numeric_limits::max() << "], got '" << s + << "'\n"; + std::exit(1); + } + return static_cast(v); + }; + + auto parse_columns = [](char const* s) -> std::optional> { + if (s == nullptr) { + return std::nullopt; + } + std::string str{s}; + if (str.empty()) { + return std::nullopt; + } + std::vector cols; + std::size_t start = 0; + while (start <= str.size()) { + auto const comma = str.find(',', start); + auto const end = (comma == std::string::npos) ? str.size() : comma; + auto const token = str.substr(start, end - start); + if (token.empty()) { + std::cerr << "Error: --columns contains an empty column name\n"; + std::exit(1); + } + cols.push_back(token); + if (comma == std::string::npos) { + break; + } + start = comma + 1; + } + return cols; + }; + + while ((opt = getopt_long(argc, argv, "", long_options, &option_index)) != -1) { + switch (opt) { + case 1: // --num-streaming-threads + options.num_streaming_threads = + require_positive_i32(optarg, "--num-streaming-threads"); + break; + case 2: // --num-rows-per-chunk + options.num_rows_per_chunk = + require_positive_i32(optarg, "--num-rows-per-chunk"); + break; + case 3: // --num-producers + options.num_producers = require_positive_size_t(optarg, "--num-producers"); + break; + case 4: // --num-consumers + options.num_consumers = require_positive_size_t(optarg, "--num-consumers"); + break; + case 5: // --input-directory + if (optarg == nullptr || *optarg == '\0') { + std::cerr << "Error: --input-directory requires a non-empty value\n"; + std::exit(1); + } + options.input_directory = optarg; + saw_input_directory = true; + break; + case 6: // --left-input-file + if (optarg == nullptr || *optarg == '\0') { + std::cerr << "Error: --left-input-file requires a non-empty value\n"; + std::exit(1); + } + options.left_input_file = optarg; + saw_left_input_file = true; + break; + case 12: // --right-input-file + if (optarg == nullptr || *optarg == '\0') { + std::cerr << "Error: --right-input-file requires a non-empty value\n"; + std::exit(1); + } + options.right_input_file = optarg; + saw_right_input_file = true; + break; + case 7: // --help + print_usage(); + std::exit(0); + case 8: // --num-iterations + options.num_iterations = require_positive_i32(optarg, "--num-iterations"); + break; + case 9: // --num-streams + options.num_streams = require_positive_i32(optarg, "--num-streams"); + break; + case 10: + { // --comm-type + if (optarg == nullptr || *optarg == '\0') { + std::cerr << "Error: --comm-type requires a value\n"; + std::exit(1); + } + std::string_view const s{optarg}; + auto parsed = std::optional{}; + for (std::size_t i = 0; i < comm_names.size(); ++i) { + if (s == comm_names[i]) { + parsed = static_cast(i); + break; + } + } + if (!parsed.has_value()) { + std::cerr << "Error: invalid --comm-type '" << s + << "' (expected: single, mpi, ucxx)\n"; + std::exit(1); + } + 
options.comm_type = *parsed;
+                break;
+            }
+        case 11:  // --left-columns
+            options.left_columns = parse_columns(optarg);
+            break;
+        case 13:  // --right-columns
+            options.right_columns = parse_columns(optarg);
+            break;
+        case '?':
+            if (optopt == 0 && optind > 1) {
+                std::cerr << "Error: Unknown option '" << argv[optind - 1] << "'\n\n";
+            }
+            print_usage();
+            std::exit(1);
+        default:
+            print_usage();
+            std::exit(1);
+        }
+    }
+
+    // Check if required options were provided
+    if (!saw_input_directory || !saw_left_input_file || !saw_right_input_file) {
+        if (!saw_input_directory) {
+            std::cerr << "Error: --input-directory is required\n";
+        }
+        if (!saw_left_input_file) {
+            std::cerr << "Error: --left-input-file is required\n";
+        }
+        if (!saw_right_input_file) {
+            std::cerr << "Error: --right-input-file is required\n";
+        }
+        std::cerr << std::endl;
+        print_usage();
+        std::exit(1);
+    }
+
+    return options;
+}
+
+}  // namespace
+
+/**
+ * @brief Run a benchmark that reads two tables from parquet files and joins them
+ * using the adaptive join strategy.
+ */
+int main(int argc, char** argv) {
+    rapidsmpf::ndsh::FinalizeMPI finalize{};
+    cudaFree(nullptr);
+    // work around https://github.com/rapidsai/cudf/issues/20849
+    cudf::initialize();
+    auto mr = rmm::mr::cuda_async_memory_resource{};
+    auto stats_wrapper = rapidsmpf::RmmResourceAdaptor(&mr);
+    auto arguments = parse_arguments(argc, argv);
+    rapidsmpf::ndsh::ProgramOptions ctx_arguments{
+        .num_streaming_threads = arguments.num_streaming_threads,
+        .num_iterations = arguments.num_iterations,
+        .num_streams = arguments.num_streams,
+        .comm_type = arguments.comm_type,
+        .num_rows_per_chunk = arguments.num_rows_per_chunk,
+        .output_file = "",
+        .input_directory = arguments.input_directory
+    };
+
+    auto ctx = rapidsmpf::ndsh::create_context(ctx_arguments, &stats_wrapper);
+    std::vector<double> timings;
+    for (int i = 0; i < arguments.num_iterations; i++) {
+        std::vector<rapidsmpf::streaming::Node> nodes;
+        int op_id = 0;
+        auto start = std::chrono::steady_clock::now();
+        {
+            RAPIDSMPF_NVTX_SCOPED_RANGE("Constructing adaptive_join pipeline");
+
+            // Input data channels
+            auto left_out = ctx->create_channel();
+            auto right_out = ctx->create_channel();
+            auto left_meta = ctx->create_channel();
+            auto right_meta = ctx->create_channel();
+            auto left_estimate = estimate_read_parquet_messages(
+                ctx,
+                arguments.input_directory,
+                arguments.left_input_file,
+                arguments.num_rows_per_chunk
+            );
+            auto right_estimate = estimate_read_parquet_messages(
+                ctx,
+                arguments.input_directory,
+                arguments.right_input_file,
+                arguments.num_rows_per_chunk
+            );
+            nodes.push_back(read_parquet(
+                ctx,
+                left_out,
+                arguments.num_producers,
+                arguments.num_rows_per_chunk,
+                arguments.left_columns,
+                arguments.input_directory,
+                arguments.left_input_file
+            ));
+            nodes.push_back(read_parquet(
+                ctx,
+                right_out,
+                arguments.num_producers,
+                arguments.num_rows_per_chunk,
+                arguments.right_columns,
+                arguments.input_directory,
+                arguments.right_input_file
+            ));
+            nodes.push_back(advertise_message_count(ctx, left_meta, left_estimate));
+            nodes.push_back(advertise_message_count(ctx, right_meta, right_estimate));
+            auto joined = ctx->create_channel();
+            auto const size_tag = static_cast<rapidsmpf::OpID>(10 * i) + op_id++;
+            auto const left_shuffle_tag = static_cast<rapidsmpf::OpID>(10 * i) + op_id++;
+            auto const right_shuffle_tag = static_cast<rapidsmpf::OpID>(10 * i) + op_id++;
+            nodes.push_back(
+                rapidsmpf::ndsh::adaptive_inner_join(
+                    ctx,
+                    left_out,
+                    right_out,
+                    left_meta,
+                    right_meta,
+                    joined,
+                    {0},
+                    {0},
+                    size_tag,
+                    left_shuffle_tag,
+                    right_shuffle_tag
+                )
+            );
+            nodes.push_back(
+                consume_channel_parallel(ctx, joined, arguments.num_consumers)
+            );
+        }
+        auto end = std::chrono::steady_clock::now();
+        std::chrono::duration<double> pipeline = end - start;
+        start = std::chrono::steady_clock::now();
+        {
+            RAPIDSMPF_NVTX_SCOPED_RANGE("adaptive_join iteration");
+            rapidsmpf::streaming::run_streaming_pipeline(std::move(nodes));
+        }
+        end = std::chrono::steady_clock::now();
+        std::chrono::duration<double> compute = end - start;
+        timings.push_back(pipeline.count());
+        timings.push_back(compute.count());
+        ctx->comm()->logger().print(ctx->statistics()->report());
+        ctx->statistics()->clear();
+    }
+
+    if (ctx->comm()->rank() == 0) {
+        for (int i = 0; i < arguments.num_iterations; i++) {
+            ctx->comm()->logger().print(
+                "Iteration ",
+                i,
+                " pipeline construction time [s]: ",
+                timings[size_t(2 * i)]
+            );
+            ctx->comm()->logger().print(
+                "Iteration ", i, " compute time [s]: ", timings[size_t(2 * i + 1)]
+            );
+        }
+    }
+    return 0;
+}
diff --git a/cpp/benchmarks/streaming/ndsh/join.cpp b/cpp/benchmarks/streaming/ndsh/join.cpp
index 52b5f8f6c..15640c3f0 100644
--- a/cpp/benchmarks/streaming/ndsh/join.cpp
+++ b/cpp/benchmarks/streaming/ndsh/join.cpp
@@ -6,7 +6,11 @@
 #include "join.hpp"
 
 #include 
+#include 
+#include 
 #include 
+#include 
+#include 
 
 #include 
 #include 
@@ -25,12 +29,14 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
 #include 
 #include 
 #include 
+#include 
 
 #include 
 #include 
@@ -144,75 +150,87 @@ streaming::Node broadcast(
     co_await ch_out->drain(ctx->executor());
 }
 
+namespace {
+
 /**
  * @brief Join a table chunk against a build hash table returning a message of the result.
  *
  * @param ctx Streaming context
- * @param right_chunk Chunk to join
+ * @param probe_chunk Chunk to join
  * @param sequence Sequence number of the output
  * @param joiner hash_join object, representing the build table.
  * @param build_carrier Columns from the build-side table to be included in the output.
- * @param right_on Key column indiecs in `right_chunk`.
+ * The caller is responsible for selecting whether build-side key columns are present.
+ * @param probe_keys Key column indices in `probe_chunk`.
+ * @param keep_keys Whether the join keys are included in the output.
+ * @param broadcast_side Which table in the join was broadcast (affects returned column
+ * ordering).
  * @param build_stream Stream the `joiner` will be deallocated on.
  * @param build_event Event recording the creation of the `joiner`.
+ * @param dealloc_event Event to use to stream-order deallocations.
  *
  * @return Message of `TableChunk` containing the result of the inner join.
*/ streaming::Message inner_join_chunk( std::shared_ptr ctx, - streaming::TableChunk&& right_chunk, + streaming::TableChunk&& probe_chunk, std::uint64_t sequence, cudf::hash_join& joiner, cudf::table_view build_carrier, - std::vector right_on, + std::vector probe_keys, + KeepKeys keep_keys, + BroadcastSide broadcast_side, rmm::cuda_stream_view build_stream, - CudaEvent* build_event + CudaEvent* build_event, + CudaEvent* dealloc_event ) { - CudaEvent event; - right_chunk = to_device(ctx, std::move(right_chunk)); - auto chunk_stream = right_chunk.stream(); + probe_chunk = to_device(ctx, std::move(probe_chunk)); + auto chunk_stream = probe_chunk.stream(); build_event->stream_wait(chunk_stream); - auto probe_table = right_chunk.table_view(); - auto probe_keys = probe_table.select(right_on); - auto [probe_match, build_match] = - joiner.inner_join(probe_keys, std::nullopt, chunk_stream, ctx->br()->device_mr()); + auto probe_table = probe_chunk.table_view(); + auto probe_table_keys = probe_table.select(probe_keys); + auto [probe_match, build_match] = joiner.inner_join( + probe_table_keys, std::nullopt, chunk_stream, ctx->br()->device_mr() + ); - cudf::column_view build_indices = + cudf::column_view build_match_indices = cudf::device_span(*build_match); - cudf::column_view probe_indices = + cudf::column_view probe_match_indices = cudf::device_span(*probe_match); - // build_carrier is valid on build_stream, but chunk_stream is - // waiting for build_stream work to be done, so running this on - // chunk_stream is fine. - auto result_columns = cudf::gather( - build_carrier, - build_indices, - cudf::out_of_bounds_policy::DONT_CHECK, - chunk_stream, - ctx->br()->device_mr() - ) - ->release(); - // drop key columns from probe table. - std::vector to_keep; - std::ranges::copy_if( - std::ranges::iota_view(0, probe_table.num_columns()), - std::back_inserter(to_keep), - [&](auto i) { return std::ranges::find(right_on, i) == right_on.end(); } - ); - std::ranges::move( - cudf::gather( - probe_table.select(to_keep), - probe_indices, - cudf::out_of_bounds_policy::DONT_CHECK, - chunk_stream, - ctx->br()->device_mr() + + std::vector> result_columns; + auto gather_columns = [&](cudf::table_view table, cudf::column_view indices) { + auto gathered = cudf::gather( + table, + indices, + cudf::out_of_bounds_policy::DONT_CHECK, + chunk_stream, + ctx->br()->device_mr() ) - ->release(), - std::back_inserter(result_columns) - ); - // Deallocation of the join indices will happen on build_stream, so add stream dep - // This also ensure deallocation of the hash_join object waits for completion. 
- cuda_stream_join(build_stream, chunk_stream, &event); + ->release(); + std::ranges::move(gathered, std::back_inserter(result_columns)); + }; + + auto const broadcast_left = broadcast_side == BroadcastSide::LEFT; + auto probe_carrier = probe_table; + if (keep_keys == KeepKeys::NO || broadcast_left) { + auto probe_indices = + std::views::iota(cudf::size_type{0}, probe_table.num_columns()) + | std::views::filter([&](auto i) { + return std::ranges::find(probe_keys, i) == probe_keys.end(); + }); + probe_carrier = probe_carrier.select(probe_indices.begin(), probe_indices.end()); + } + + if (broadcast_left) { + gather_columns(build_carrier, build_match_indices); + gather_columns(probe_carrier, probe_match_indices); + } else { + gather_columns(probe_carrier, probe_match_indices); + gather_columns(build_carrier, build_match_indices); + } + + cuda_stream_join(build_stream, chunk_stream, dealloc_event); return streaming::to_message( sequence, std::make_unique( @@ -220,24 +238,30 @@ streaming::Message inner_join_chunk( ) ); } +} // namespace streaming::Node inner_join_broadcast( std::shared_ptr ctx, - // We will always choose left as build table and do "broadcast" joins std::shared_ptr left, std::shared_ptr right, std::shared_ptr ch_out, std::vector left_on, std::vector right_on, OpID tag, - KeepKeys keep_keys + KeepKeys keep_keys, + BroadcastSide broadcast_side ) { streaming::ShutdownAtExit c{left, right, ch_out}; co_await ctx->executor()->schedule(); ctx->comm()->logger().print("Inner broadcast join ", static_cast(tag)); + auto const broadcast_left = broadcast_side == BroadcastSide::LEFT; + auto build = broadcast_left ? left : right; + auto probe = broadcast_left ? right : left; + auto build_keys = broadcast_left ? left_on : right_on; + auto probe_keys = broadcast_left ? 
right_on : left_on; auto build_table = to_device( ctx, - (co_await broadcast(ctx, left, tag, streaming::AllGather::Ordered::NO)) + (co_await broadcast(ctx, build, tag, streaming::AllGather::Ordered::NO)) .release() ); ctx->comm()->logger().print( @@ -245,38 +269,39 @@ streaming::Node inner_join_broadcast( ); auto joiner = cudf::hash_join( - build_table.table_view().select(left_on), + build_table.table_view().select(build_keys), cudf::null_equality::UNEQUAL, build_table.stream() ); CudaEvent build_event; + CudaEvent dealloc_event; build_event.record(build_table.stream()); - cudf::table_view build_carrier; - if (keep_keys == KeepKeys::YES) { - build_carrier = build_table.table_view(); - } else { - std::vector to_keep; - std::ranges::copy_if( - std::ranges::iota_view(0, build_table.table_view().num_columns()), - std::back_inserter(to_keep), - [&](auto i) { return std::ranges::find(left_on, i) == left_on.end(); } - ); - build_carrier = build_table.table_view().select(to_keep); + auto build_carrier = build_table.table_view(); + if (keep_keys == KeepKeys::NO || !broadcast_left) { + auto build_indices = + std::views::iota(cudf::size_type{0}, build_carrier.num_columns()) + | std::views::filter([&](auto i) { + return std::ranges::find(build_keys, i) == build_keys.end(); + }); + build_carrier = build_carrier.select(build_indices.begin(), build_indices.end()); } while (!ch_out->is_shutdown()) { - auto right_msg = co_await right->receive(); - if (right_msg.empty()) { + auto probe_msg = co_await probe->receive(); + if (probe_msg.empty()) { break; } co_await ch_out->send(inner_join_chunk( ctx, - right_msg.release(), - right_msg.sequence_number(), + probe_msg.release(), + probe_msg.sequence_number(), joiner, build_carrier, - right_on, + probe_keys, + keep_keys, + broadcast_side, build_table.stream(), - &build_event + &build_event, + &dealloc_event )); } @@ -296,6 +321,7 @@ streaming::Node inner_join_shuffle( ctx->comm()->logger().print("Inner shuffle join"); co_await ctx->executor()->schedule(); CudaEvent build_event; + CudaEvent dealloc_event; while (!ch_out->is_shutdown()) { // Requirement: two shuffles kick out partitions in the same order auto left_msg = co_await left->receive(); @@ -319,17 +345,15 @@ streaming::Node inner_join_shuffle( build_stream ); build_event.record(build_stream); - cudf::table_view build_carrier; - if (keep_keys == KeepKeys::YES) { - build_carrier = build_chunk.table_view(); - } else { - std::vector to_keep; - std::ranges::copy_if( - std::ranges::iota_view(0, build_chunk.table_view().num_columns()), - std::back_inserter(to_keep), - [&](auto i) { return std::ranges::find(left_on, i) == left_on.end(); } - ); - build_carrier = build_chunk.table_view().select(to_keep); + auto build_carrier = build_chunk.table_view(); + if (keep_keys == KeepKeys::NO) { + auto build_indices = + std::views::iota(cudf::size_type{0}, build_carrier.num_columns()) + | std::views::filter([&](auto i) { + return std::ranges::find(left_on, i) == left_on.end(); + }); + build_carrier = + build_carrier.select(build_indices.begin(), build_indices.end()); } co_await ch_out->send(inner_join_chunk( ctx, @@ -338,8 +362,11 @@ streaming::Node inner_join_shuffle( joiner, build_carrier, right_on, + keep_keys, + BroadcastSide::LEFT, build_stream, - &build_event + &build_event, + &dealloc_event )); } co_await ch_out->drain(ctx->executor()); @@ -403,4 +430,207 @@ streaming::Node shuffle( co_await ch_out->drain(ctx->executor()); } +namespace { +coro::task> allgather_join_sizes( + std::shared_ptr ctx, + OpID tag, + 
std::size_t left_local_bytes, + std::size_t right_local_bytes +) { + auto metadata = std::make_unique>(2 * sizeof(std::size_t)); + std::memcpy(metadata->data(), &left_local_bytes, sizeof(left_local_bytes)); + std::memcpy( + metadata->data() + sizeof(left_local_bytes), + &right_local_bytes, + sizeof(right_local_bytes) + ); + + auto stream = ctx->br()->stream_pool().get_stream(); + auto [res, _] = ctx->br()->reserve(MemoryType::HOST, 0, true); + auto buf = ctx->br()->allocate(stream, std::move(res)); + auto allgather = streaming::AllGather(ctx, tag); + allgather.insert(0, {std::move(metadata), std::move(buf)}); + allgather.insert_finished(); + auto per_rank = co_await allgather.extract_all(streaming::AllGather::Ordered::NO); + + std::size_t left_total_bytes = 0; + std::size_t right_total_bytes = 0; + for (auto const& data : per_rank) { + RAPIDSMPF_EXPECTS( + data.metadata->size() >= 2 * sizeof(std::size_t), + "Invalid metadata size for adaptive join size estimation" + ); + std::size_t bytes = 0; + std::memcpy(&bytes, data.metadata->data(), sizeof(bytes)); + left_total_bytes += bytes; + std::memcpy(&bytes, data.metadata->data() + sizeof(bytes), sizeof(bytes)); + right_total_bytes += bytes; + } + co_return {left_total_bytes, right_total_bytes}; +} + +streaming::Node replay_channel( + std::shared_ptr ctx, + std::shared_ptr input, + std::shared_ptr output, + std::vector buffer +) { + streaming::ShutdownAtExit c{input, output}; + co_await ctx->executor()->schedule(); + for (auto&& msg : buffer) { + if (msg.empty()) { + co_await output->drain(ctx->executor()); + co_return; + } + if (!co_await output->send(std::move(msg))) { + co_return; + } + } + while (!output->is_shutdown()) { + auto msg = co_await input->receive(); + if (msg.empty()) { + break; + } + co_await output->send(std::move(msg)); + } + co_await output->drain(ctx->executor()); +} +} // namespace + +streaming::Node adaptive_inner_join( + std::shared_ptr ctx, + std::shared_ptr left, + std::shared_ptr right, + std::shared_ptr left_meta, + std::shared_ptr right_meta, + std::shared_ptr ch_out, + std::vector left_keys, + std::vector right_keys, + OpID allreduce_tag, + OpID left_shuffle_tag, + OpID right_shuffle_tag +) { + streaming::ShutdownAtExit c{left, right, left_meta, right_meta, ch_out}; + co_await ctx->executor()->schedule(); + + auto consume_meta = [&]( + std::shared_ptr ch + ) -> coro::task> { + auto msg = co_await ch->receive(); + if (msg.empty()) { + co_return std::nullopt; + } + co_return msg.release(); + }; + auto [num_left_messages, num_right_messages] = streaming::coro_results( + co_await coro::when_all(consume_meta(left_meta), consume_meta(right_meta)) + ); + // Assumption: the input channels carry only TableChunk messages. + // Assumption: summing data_alloc_size across memory types is a good proxy for the + // amount of data that will need to be materialized on device for compute. + // Assumption: a small sample of chunks is representative of the whole table size. + // Assumption: metadata estimates reflect the total number of chunks per input. 
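+    // Illustrative decision (hypothetical sizes): left = 100 GiB and right = 2 GiB give
+    // min = 2 GiB <= 8 GiB cap and ratio = 0.02 <= 0.10, so the smaller side is
+    // broadcast; otherwise both inputs are shuffled and joined partition-wise.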
+ constexpr std::size_t broadcast_cap_bytes = 8ULL * 1024ULL * 1024ULL * 1024ULL; + constexpr double broadcast_ratio_threshold = 0.10; + constexpr std::size_t inspect_messages = 2; + std::vector left_buffer; + std::vector right_buffer; + left_buffer.reserve(inspect_messages); + right_buffer.reserve(inspect_messages); + + auto inspect_channel = + [&](std::shared_ptr ch, + std::vector& buffer, + std::size_t estimated_num_messages) -> coro::task { + std::size_t bytes = 0; + for (std::size_t i = 0; i < inspect_messages; ++i) { + auto msg = co_await ch->receive(); + if (msg.empty()) { + buffer.push_back(std::move(msg)); + co_return bytes; + } + auto const& chunk = msg.get(); + for (auto mem_type : MEMORY_TYPES) { + bytes += chunk.data_alloc_size(mem_type); + } + buffer.push_back(std::move(msg)); + } + co_return (bytes * estimated_num_messages) / inspect_messages; + }; + + auto left_local_bytes = co_await inspect_channel( + left, left_buffer, num_left_messages.value_or(inspect_messages) + ); + auto right_local_bytes = co_await inspect_channel( + right, right_buffer, num_right_messages.value_or(inspect_messages) + ); + + std::size_t left_total_bytes = left_local_bytes; + std::size_t right_total_bytes = right_local_bytes; + if (ctx->comm()->nranks() > 1) { + std::tie(left_total_bytes, right_total_bytes) = co_await allgather_join_sizes( + ctx, allreduce_tag, left_local_bytes, right_local_bytes + ); + } + + ctx->comm()->logger().print( + "Adaptive join total sizes: left ", + left_total_bytes, + " bytes, right ", + right_total_bytes, + " bytes" + ); + auto const min_bytes = std::min(left_total_bytes, right_total_bytes); + auto const max_bytes = std::max(left_total_bytes, right_total_bytes); + auto const broadcast_ratio = + (max_bytes == 0) ? 0.0 : static_cast(min_bytes) / max_bytes; + auto const use_broadcast = + min_bytes <= broadcast_cap_bytes && broadcast_ratio <= broadcast_ratio_threshold; + auto left_replay = ctx->create_channel(); + auto right_replay = ctx->create_channel(); + std::vector tasks; + tasks.push_back(replay_channel(ctx, left, left_replay, std::move(left_buffer))); + tasks.push_back(replay_channel(ctx, right, right_replay, std::move(right_buffer))); + if (use_broadcast) { + ctx->comm()->logger().print("Adaptive join strategy: broadcast"); + auto const broadcast_left = left_total_bytes <= right_total_bytes; + auto build = broadcast_left ? left_replay : right_replay; + auto probe = broadcast_left ? right_replay : left_replay; + auto build_keys = broadcast_left ? left_keys : right_keys; + auto probe_keys = broadcast_left ? right_keys : left_keys; + auto const broadcast_tag = broadcast_left ? left_shuffle_tag : right_shuffle_tag; + tasks.push_back(inner_join_broadcast( + ctx, + build, + probe, + ch_out, + std::move(build_keys), + std::move(probe_keys), + broadcast_tag, + KeepKeys::YES, + broadcast_left ? 
BroadcastSide::LEFT : BroadcastSide::RIGHT
+        ));
+    } else {
+        ctx->comm()->logger().print("Adaptive join strategy: shuffle");
+        auto const num_partitions = static_cast<std::uint32_t>(ctx->comm()->nranks());
+        auto left_shuffled = ctx->create_channel();
+        auto right_shuffled = ctx->create_channel();
+        tasks.push_back(shuffle(
+            ctx, left_replay, left_shuffled, left_keys, num_partitions, left_shuffle_tag
+        ));
+        tasks.push_back(shuffle(
+            ctx,
+            right_replay,
+            right_shuffled,
+            right_keys,
+            num_partitions,
+            right_shuffle_tag
+        ));
+        tasks.push_back(inner_join_shuffle(
+            ctx, left_shuffled, right_shuffled, ch_out, left_keys, right_keys
+        ));
+    }
+    streaming::coro_results(co_await coro::when_all(std::move(tasks)));
+    co_await ch_out->drain(ctx->executor());
+}
 }  // namespace rapidsmpf::ndsh
diff --git a/cpp/benchmarks/streaming/ndsh/join.hpp b/cpp/benchmarks/streaming/ndsh/join.hpp
index 67e701817..09f9721e9 100644
--- a/cpp/benchmarks/streaming/ndsh/join.hpp
+++ b/cpp/benchmarks/streaming/ndsh/join.hpp
@@ -1,5 +1,5 @@
 /**
- * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  */
 
@@ -22,6 +22,12 @@ enum class KeepKeys : bool {
     YES,  ///< Key columns do appear in the output
 };
 
+/// @brief Which input of a broadcast join is broadcast; the broadcast side is used as the build side.
+enum class BroadcastSide : bool {
+    LEFT,  ///< Broadcast the left input
+    RIGHT,  ///< Broadcast the right input
+};
+
 /**
  * @brief Broadcast the concatenation of all input messages to all ranks.
  *
@@ -67,31 +73,33 @@ enum class KeepKeys : bool {
 /**
  * @brief Perform a streaming inner join between two tables.
  *
- * @note This performs a broadcast join, broadcasting the table represented by the `left`
- * channel to all ranks, and then streaming through the chunks of the `right` channel.
+ * @note This performs a broadcast join, broadcasting the chosen broadcast side to all
+ * ranks and then streaming through the chunks of the non-broadcast side. Output column
+ * ordering follows the logical left input: left keys (if kept), left non-keys, then
+ * right non-keys.
 *
 * @param ctx Streaming context.
- * @param left Channel of `TableChunk`s used as the broadcasted build side.
- * @param right Channel of `TableChunk`s joined in turn against the build side.
+ * @param left Channel of `TableChunk`s for the left input.
+ * @param right Channel of `TableChunk`s for the right input.
 * @param ch_out Output channel of `TableChunk`s.
 * @param left_on Column indices of the keys in the left table.
 * @param right_on Column indices of the keys in the right table.
- * @param tag Disambiguating tag for the broadcast of the left table.
- * @param keep_keys Does the result contain the key columns, or only "carrier" value
- * columns
+ * @param tag Disambiguating tag for the broadcast.
+ * @param keep_keys Whether the result contains the left key columns.
+ * @param broadcast_side Which input should be broadcast.
 *
 * @return Coroutine representing the completion of the join.
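+ *
+ * A hypothetical wiring sketch (the channels, keys, and tag names are illustrative,
+ * not part of this header):
+ * @code{.cpp}
+ * auto joined = ctx->create_channel();
+ * auto node = inner_join_broadcast(
+ *     ctx, left_ch, right_ch, joined, {0}, {0}, tag, KeepKeys::YES,
+ *     BroadcastSide::RIGHT  // broadcast (and build on) the smaller right input
+ * );
+ * @endcode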
 */
 [[nodiscard]] streaming::Node inner_join_broadcast(
     std::shared_ptr<streaming::Context> ctx,
-    // We will always choose left as build table and do "broadcast" joins
     std::shared_ptr<streaming::Channel> left,
     std::shared_ptr<streaming::Channel> right,
     std::shared_ptr<streaming::Channel> ch_out,
     std::vector<cudf::size_type> left_on,
     std::vector<cudf::size_type> right_on,
     OpID tag,
-    KeepKeys keep_keys = KeepKeys::YES
+    KeepKeys keep_keys = KeepKeys::YES,
+    BroadcastSide broadcast_side = BroadcastSide::LEFT
 );
 
 /**
  * @brief Perform a streaming inner join between two tables.
@@ -141,4 +149,38 @@ enum class KeepKeys : bool {
     OpID tag
 );
 
+/**
+ * @brief Perform a streaming inner join between two tables using an adaptive strategy.
+ *
+ * @note This inspects a small prefix of each input channel to estimate input sizes and
+ * uses that estimate to choose between broadcast and shuffle join implementations.
+ *
+ * @param ctx Streaming context.
+ * @param left Channel of `TableChunk`s from the left table.
+ * @param right Channel of `TableChunk`s from the right table.
+ * @param left_meta Channel carrying an estimated message count for the left input (the
+ * estimate is optional; the channel may be drained empty).
+ * @param right_meta Channel carrying an estimated message count for the right input
+ * (likewise optional).
+ * @param ch_out Output channel of `TableChunk`s.
+ * @param left_keys Column indices of the keys in the left table.
+ * @param right_keys Column indices of the keys in the right table.
+ * @param allreduce_tag Disambiguating tag for the size estimation allgather.
+ * @param left_shuffle_tag Disambiguating tag for the left shuffle (if used).
+ * @param right_shuffle_tag Disambiguating tag for the right shuffle (if used).
+ *
+ * @return Coroutine representing the completion of the join.
+ */
+[[nodiscard]] streaming::Node adaptive_inner_join(
+    std::shared_ptr<streaming::Context> ctx,
+    std::shared_ptr<streaming::Channel> left,
+    std::shared_ptr<streaming::Channel> right,
+    std::shared_ptr<streaming::Channel> left_meta,
+    std::shared_ptr<streaming::Channel> right_meta,
+    std::shared_ptr<streaming::Channel> ch_out,
+    std::vector<cudf::size_type> left_keys,
+    std::vector<cudf::size_type> right_keys,
+    OpID allreduce_tag,
+    OpID left_shuffle_tag,
+    OpID right_shuffle_tag
+);
+
 }  // namespace rapidsmpf::ndsh
diff --git a/cpp/include/rapidsmpf/streaming/core/channel.hpp b/cpp/include/rapidsmpf/streaming/core/channel.hpp
index 733126242..959a154a7 100644
--- a/cpp/include/rapidsmpf/streaming/core/channel.hpp
+++ b/cpp/include/rapidsmpf/streaming/core/channel.hpp
@@ -18,6 +18,7 @@
 #include 
 #include 
+#include <coro/queue.hpp>
 #include 
 
 namespace rapidsmpf::streaming {
@@ -50,7 +51,7 @@ class Channel {
      *
      * @throws std::logic_error If the message is empty.
      */
-    coro::task<bool> send(Message msg);
+    [[nodiscard]] coro::task<bool> send(Message msg);
 
     /**
      * @brief Asynchronously receive a message from the channel.
@@ -62,7 +63,41 @@
      *
      * @throws std::logic_error If the received message is empty.
      */
-    coro::task<Message> receive();
+    [[nodiscard]] coro::task<Message> receive();
+
+    /**
+     * @brief Asynchronously send a metadata message into the channel.
+     *
+     * Suspends if the metadata queue is locked by another producer.
+     *
+     * @param msg The metadata message to send.
+     * @return A coroutine that evaluates to true if the message was successfully sent
+     * or false if the channel was shut down.
+     *
+     * @throws std::logic_error If the message is empty.
+     */
+    [[nodiscard]] coro::task<bool> send_metadata(Message msg);
+
+    /**
+     * @brief Asynchronously receive a metadata message from the channel.
+     *
+     * Suspends if the metadata queue is empty.
+     *
+     * @return A coroutine that evaluates to the message, which will be empty if the
+     * metadata queue is shut down.
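+     *
+     * A minimal consumer sketch (illustrative; assumes a coroutine with access to a
+     * channel `ch` and a hypothetical `handle_metadata` callback):
+     * @code{.cpp}
+     * while (true) {
+     *     auto msg = co_await ch->receive_metadata();
+     *     if (msg.empty()) {
+     *         break;  // Metadata queue drained or shut down.
+     *     }
+     *     handle_metadata(std::move(msg));
+     * }
+     * @endcode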
+ */ + [[nodiscard]] coro::task receive_metadata(); + + /** + * @brief Drains all pending metadata messages from the channel and shuts down the + * metadata channel. + * + * This is intended to ensure all remaining metadata messages are processed. + * + * @param executor The thread pool used to process remaining messages. + * @return A coroutine representing the completion of the metadata shutdown drain. + */ + [[nodiscard]] Node drain_metadata(std::shared_ptr executor); /** * @brief Drains all pending messages from the channel and shuts it down. @@ -72,7 +107,7 @@ class Channel { * @param executor The thread pool used to process remaining messages. * @return A coroutine representing the completion of the shutdown drain. */ - Node drain(std::shared_ptr executor); + [[nodiscard]] Node drain(std::shared_ptr executor); /** * @brief Immediately shuts down the channel. @@ -81,7 +116,16 @@ class Channel { * * @return A coroutine representing the completion of the shutdown. */ - Node shutdown(); + [[nodiscard]] Node shutdown(); + + /** + * @brief Immediately shuts down the metadata queue. + * + * Any pending or future metadata send/receive operations will complete with failure. + * + * @return A coroutine representing the completion of the shutdown. + */ + [[nodiscard]] Node shutdown_metadata(); /** * @brief Check whether the channel is empty. @@ -103,6 +147,7 @@ class Channel { coro::ring_buffer rb_; std::shared_ptr sm_; + coro::queue metadata_; }; /** diff --git a/cpp/include/rapidsmpf/streaming/core/lineariser.hpp b/cpp/include/rapidsmpf/streaming/core/lineariser.hpp index edad29951..237bb4da0 100644 --- a/cpp/include/rapidsmpf/streaming/core/lineariser.hpp +++ b/cpp/include/rapidsmpf/streaming/core/lineariser.hpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 */ @@ -94,7 +94,8 @@ class Lineariser { Node drain() { ShutdownAtExit c{ch_out_}; co_await ctx_->executor()->schedule(); - while (!queues_.empty()) { + bool should_continue = true; + while (should_continue && !queues_.empty()) { for (auto& q : queues_) { auto [receipt, msg] = co_await q->receive(); if (msg.empty()) { @@ -103,6 +104,7 @@ class Lineariser { } if (!co_await ch_out_->send(std::move(msg))) { // Output channel is shut down, tell the producers to shutdown. 
+ should_continue = false; break; } co_await receipt; diff --git a/cpp/src/streaming/core/channel.cpp b/cpp/src/streaming/core/channel.cpp index dde57f138..1fdf3e02a 100644 --- a/cpp/src/streaming/core/channel.cpp +++ b/cpp/src/streaming/core/channel.cpp @@ -23,12 +23,39 @@ coro::task Channel::receive() { } } +coro::task Channel::send_metadata(Message msg) { + RAPIDSMPF_EXPECTS(!msg.empty(), "message cannot be empty"); + auto result = co_await metadata_.push(std::move(msg)); + co_return result == coro::queue_produce_result::produced; +} + +Node Channel::drain_metadata(std::shared_ptr executor) { + return metadata_.shutdown_drain(executor->get()); +} + +coro::task Channel::receive_metadata() { + auto msg = co_await metadata_.pop(); + if (msg.has_value()) { + co_return std::move(*msg); + } else { + co_return Message{}; + } +} + Node Channel::drain(std::shared_ptr executor) { - return rb_.shutdown_drain(executor->get()); + coro_results( + co_await coro::when_all( + rb_.shutdown_drain(executor->get()), drain_metadata(executor) + ) + ); } Node Channel::shutdown() { - return rb_.shutdown(); + coro_results(co_await coro::when_all(metadata_.shutdown(), rb_.shutdown())); +} + +Node Channel::shutdown_metadata() { + return metadata_.shutdown(); } bool Channel::empty() const noexcept { diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index d695ba3a2..388c81476 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -98,6 +98,7 @@ if(RAPIDSMPF_HAVE_STREAMING) target_sources( test_sources PRIVATE streaming/test_allgather.cpp + streaming/test_channel.cpp streaming/test_error_handling.cpp streaming/test_fanout.cpp streaming/test_leaf_node.cpp diff --git a/cpp/tests/streaming/test_channel.cpp b/cpp/tests/streaming/test_channel.cpp new file mode 100644 index 000000000..54a89fa1d --- /dev/null +++ b/cpp/tests/streaming/test_channel.cpp @@ -0,0 +1,319 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include + +#include +#include +#include +#include +#include + +#include "base_streaming_fixture.hpp" + +using namespace rapidsmpf::streaming; + +namespace { +std::vector make_int_messages(std::size_t n) { + std::vector messages; + messages.reserve(n); + for (std::size_t i = 0; i < n; ++i) { + messages.emplace_back( + i, std::make_unique(i), rapidsmpf::ContentDescription{} + ); + } + return messages; +} +} // namespace + +using BaseStreamingChannel = BaseStreamingFixture; + +TEST_F(BaseStreamingChannel, DataRoundTripWithoutMetadata) { + auto ch = ctx->create_channel(); + std::vector outputs; + std::vector nodes; + static constexpr std::size_t num_messages = 4; + nodes.emplace_back(node::push_to_channel(ctx, ch, make_int_messages(num_messages))); + nodes.emplace_back(node::pull_from_channel(ctx, ch, outputs)); + run_streaming_pipeline(std::move(nodes)); + + ASSERT_EQ(outputs.size(), num_messages); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(outputs[static_cast(i)].release(), i); + } +} + +TEST_F(BaseStreamingChannel, MetadataSendReceiveAndShutdown) { + auto ch = ctx->create_channel(); + std::vector metadata_outputs; + std::vector data_outputs; + std::vector nodes; + + auto producer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + + auto meta_task = [&]() -> Node { + co_await ch->send_metadata(Message{0, std::make_unique(10), {}}); + co_await ch->send_metadata(Message{1, std::make_unique(20), {}}); + co_await ch->drain_metadata(ctx->executor()); + }; + auto send_task = [&]() -> Node { + co_await ch->send(Message{0, std::make_unique(1), {}}); + co_await ch->send(Message{1, std::make_unique(2), {}}); + co_await ch->drain(ctx->executor()); + }; + coro_results(co_await coro::when_all(meta_task(), send_task())); + }; + + auto consumer = [this, ch, &metadata_outputs, &data_outputs]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + + while (true) { + auto msg = co_await ch->receive_metadata(); + if (msg.empty()) { + break; + } + metadata_outputs.push_back(std::move(msg)); + } + + while (true) { + auto msg = co_await ch->receive(); + if (msg.empty()) { + break; + } + data_outputs.push_back(std::move(msg)); + } + }; + + nodes.emplace_back(producer()); + nodes.emplace_back(consumer()); + run_streaming_pipeline(std::move(nodes)); + + ASSERT_EQ(metadata_outputs.size(), 2U); + EXPECT_EQ(metadata_outputs[0].get(), 10); + EXPECT_EQ(metadata_outputs[1].get(), 20); + + ASSERT_EQ(data_outputs.size(), 2U); + EXPECT_EQ(data_outputs[0].get(), 1); + EXPECT_EQ(data_outputs[1].get(), 2); +} + +TEST_F(BaseStreamingChannel, DataOnlyWithMetadataShutdown) { + auto ch = ctx->create_channel(); + std::vector data_outputs; + std::vector metadata_outputs; + std::vector nodes; + + auto producer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + + co_await ch->shutdown_metadata(); + co_await ch->send( + Message{0, std::make_unique(10), rapidsmpf::ContentDescription{}} + ); + co_await ch->send( + Message{1, std::make_unique(20), rapidsmpf::ContentDescription{}} + ); + co_await ch->drain(ctx->executor()); + }; + + auto consumer = [this, ch, &metadata_outputs, &data_outputs]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + + while (true) { + auto msg = co_await ch->receive_metadata(); + if (msg.empty()) { + break; + } + metadata_outputs.push_back(std::move(msg)); + } + + while (true) { + auto msg = co_await ch->receive(); + if 
(msg.empty()) { + break; + } + data_outputs.push_back(std::move(msg)); + } + }; + + nodes.emplace_back(producer()); + nodes.emplace_back(consumer()); + run_streaming_pipeline(std::move(nodes)); + + EXPECT_TRUE(metadata_outputs.empty()); + ASSERT_EQ(data_outputs.size(), 2U); + EXPECT_EQ(data_outputs[0].get(), 10); + EXPECT_EQ(data_outputs[1].get(), 20); +} + +TEST_F(BaseStreamingChannel, MetadataOnlyWithDataShutdown) { + auto ch = ctx->create_channel(); + std::vector metadata_outputs; + std::vector data_outputs; + std::vector nodes; + + auto producer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + + co_await ch->send_metadata( + Message{0, std::make_unique(10), rapidsmpf::ContentDescription{}} + ); + co_await ch->send_metadata( + Message{1, std::make_unique(20), rapidsmpf::ContentDescription{}} + ); + co_await ch->drain(ctx->executor()); + }; + + auto consumer = [this, ch, &metadata_outputs, &data_outputs]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + + while (true) { + auto msg = co_await ch->receive_metadata(); + if (msg.empty()) { + break; + } + metadata_outputs.push_back(std::move(msg)); + } + + while (true) { + auto msg = co_await ch->receive(); + if (msg.empty()) { + break; + } + data_outputs.push_back(std::move(msg)); + } + }; + + nodes.emplace_back(producer()); + nodes.emplace_back(consumer()); + run_streaming_pipeline(std::move(nodes)); + + ASSERT_EQ(metadata_outputs.size(), 2U); + EXPECT_EQ(metadata_outputs[0].get(), 10); + EXPECT_EQ(metadata_outputs[1].get(), 20); + EXPECT_TRUE(data_outputs.empty()); +} + +TEST_F(BaseStreamingChannel, ConsumerIgnoresMetadata) { + auto ch = ctx->create_channel(); + std::vector data_outputs; + std::vector nodes; + + auto producer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + + co_await ch->send_metadata( + Message{0, std::make_unique(10), rapidsmpf::ContentDescription{}} + ); + co_await ch->send_metadata( + Message{0, std::make_unique(20), rapidsmpf::ContentDescription{}} + ); + co_await ch->send( + Message{1, std::make_unique(30), rapidsmpf::ContentDescription{}} + ); + co_await ch->drain(ctx->executor()); + }; + + auto consumer = [this, ch, &data_outputs]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + while (true) { + auto msg = co_await ch->receive(); + if (msg.empty()) { + break; + } + data_outputs.push_back(std::move(msg)); + } + }; + + nodes.emplace_back(producer()); + nodes.emplace_back(consumer()); + run_streaming_pipeline(std::move(nodes)); + + EXPECT_EQ(data_outputs.size(), 1U); + EXPECT_EQ(data_outputs[0].get(), 30); +} + +TEST_F(BaseStreamingChannel, ProducerThrowsWithMetadata) { + auto ch = ctx->create_channel(); + std::vector nodes; + + auto producer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + co_await ch->send_metadata( + Message{0, std::make_unique(31), rapidsmpf::ContentDescription{}} + ); + throw std::runtime_error("producer failed"); + }; + + auto consumer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + while (true) { + auto msg = co_await ch->receive_metadata(); + if (msg.empty()) { + break; + } + } + }; + + nodes.emplace_back(producer()); + nodes.emplace_back(consumer()); + EXPECT_THROW(run_streaming_pipeline(std::move(nodes)), std::runtime_error); +} + +TEST_F(BaseStreamingChannel, ConsumerThrowsWithMetadata) { + auto ch = ctx->create_channel(); + std::vector nodes; + + auto producer = 
[this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + co_await ch->send_metadata( + Message{0, std::make_unique(10), rapidsmpf::ContentDescription{}} + ); + co_await ch->drain(ctx->executor()); + }; + + auto consumer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + throw std::runtime_error("consumer failed"); + }; + + nodes.emplace_back(producer()); + nodes.emplace_back(consumer()); + EXPECT_THROW(run_streaming_pipeline(std::move(nodes)), std::runtime_error); +} + +TEST_F(BaseStreamingChannel, ProducerAndConsumerThrow) { + auto ch = ctx->create_channel(); + std::vector nodes; + + auto producer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + throw std::runtime_error("producer failed"); + }; + + auto consumer = [this, ch]() -> Node { + ShutdownAtExit c{ch}; + co_await ctx->executor()->schedule(); + throw std::runtime_error("consumer failed"); + }; + + nodes.emplace_back(producer()); + nodes.emplace_back(consumer()); + EXPECT_THROW(run_streaming_pipeline(std::move(nodes)), std::runtime_error); +}