Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions cpp/include/rapidsmpf/shuffler/shuffler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <rapidsmpf/buffer/packed_data.hpp>
#include <rapidsmpf/buffer/resource.hpp>
#include <rapidsmpf/communicator/communicator.hpp>
#include <rapidsmpf/communicator/metadata_payload_exchange/tag.hpp>
#include <rapidsmpf/error.hpp>
#include <rapidsmpf/nvtx.hpp>
#include <rapidsmpf/progress_thread.hpp>
Expand Down Expand Up @@ -94,6 +95,8 @@ class Shuffler {
* @param finished_callback Callback to notify when a partition is finished.
* @param statistics The statistics instance to use (disabled by default).
* @param partition_owner Function to determine partition ownership.
* @param mpe Optional custom metadata payload exchange. If not provided,
* uses the default tag-based implementation.
*
* @note The caller promises that inserted buffers are stream-ordered with respect
* to their own stream, and extracted buffers are likewise guaranteed to be stream-
Expand All @@ -107,7 +110,8 @@ class Shuffler {
BufferResource* br,
FinishedCallback&& finished_callback,
std::shared_ptr<Statistics> statistics = Statistics::disabled(),
PartitionOwner partition_owner = round_robin
PartitionOwner partition_owner = round_robin,
std::unique_ptr<communicator::MetadataPayloadExchange> mpe = nullptr
);

/**
Expand All @@ -121,6 +125,8 @@ class Shuffler {
* @param br Buffer resource used to allocate temporary buffers and the shuffle result.
* @param statistics The statistics instance to use (disabled by default).
* @param partition_owner Function to determine partition ownership.
* @param mpe Optional custom metadata payload exchange. If not provided,
* uses the default tag-based implementation.
*
* @note The caller promises that inserted buffers are stream-ordered with respect
* to their own stream, and extracted buffers are likewise guaranteed to be stream-
Expand All @@ -133,7 +139,8 @@ class Shuffler {
PartID total_num_partitions,
BufferResource* br,
std::shared_ptr<Statistics> statistics = Statistics::disabled(),
PartitionOwner partition_owner = round_robin
PartitionOwner partition_owner = round_robin,
std::unique_ptr<communicator::MetadataPayloadExchange> mpe = nullptr
)
: Shuffler(
comm,
Expand All @@ -143,7 +150,8 @@ class Shuffler {
br,
nullptr,
statistics,
partition_owner
partition_owner,
std::move(mpe)
) {}

~Shuffler();
Expand Down Expand Up @@ -348,6 +356,7 @@ class Shuffler {
///< ready to be extracted by the user.

std::shared_ptr<Communicator> comm_;
std::unique_ptr<communicator::MetadataPayloadExchange> mpe_;
std::shared_ptr<ProgressThread> progress_thread_;
ProgressThread::FunctionID progress_thread_function_id_;
OpID const op_id_;
Expand Down
18 changes: 12 additions & 6 deletions cpp/src/communicator/metadata_payload_exchange/tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,16 @@ void TagMetadataPayloadExchange::send(
);

// Send data immediately after metadata (if any)
if (message->data() != nullptr) {
if (payload_size > 0) {
fire_and_forget_.push_back(
comm_->send(message->release_data(), dst, gpu_data_tag_)
);
}
}

statistics_->add_duration_stat("comms-interface-send-messages", Clock::now() - t0);
statistics_->add_duration_stat(
"metadata-payload-exchange-send-messages", Clock::now() - t0
);
}

void TagMetadataPayloadExchange::progress() {
Expand All @@ -111,7 +113,9 @@ void TagMetadataPayloadExchange::progress() {

cleanup_completed_operations();

statistics_->add_duration_stat("comms-interface-progress", Clock::now() - t0);
statistics_->add_duration_stat(
"metadata-payload-exchange-progress", Clock::now() - t0
);
}

std::vector<std::unique_ptr<MetadataPayloadExchange::Message>>
Expand Down Expand Up @@ -178,7 +182,9 @@ void TagMetadataPayloadExchange::receive_metadata() {
);
}

statistics_->add_duration_stat("comms-interface-receive-metadata", Clock::now() - t0);
statistics_->add_duration_stat(
"metadata-payload-exchange-receive-metadata", Clock::now() - t0
);
}

std::vector<std::unique_ptr<MetadataPayloadExchange::Message>>
Expand Down Expand Up @@ -259,7 +265,7 @@ TagMetadataPayloadExchange::setup_data_receives() {
}

statistics_->add_duration_stat(
"comms-interface-setup-data-receives", Clock::now() - t0
"metadata-payload-exchange-setup-data-receives", Clock::now() - t0
);

return completed_messages;
Expand Down Expand Up @@ -320,7 +326,7 @@ TagMetadataPayloadExchange::complete_data_transfers() {
}

statistics_->add_duration_stat(
"comms-interface-complete-data-transfers", Clock::now() - t0
"metadata-payload-exchange-complete-data-transfers", Clock::now() - t0
);

return completed_messages;
Expand Down
Loading
Loading