Skip to content

Commit d636093

Browse files
authored
Asynchronous setup (#514)
Cherry-picked part of the features from #167: `Communicator::setup()` is no longer needed. `Communicator::sendMemory()` performs its task inline, while `Communicator::recvMemory()` and `Communicator::connect()` perform their tasks asynchronously without an explicit setup step.
1 parent 8bc369c commit d636093

File tree

7 files changed

+59
-180
lines changed

7 files changed

+59
-180
lines changed

include/mscclpp/core.hpp

Lines changed: 40 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -557,64 +557,19 @@ class Context {
557557
friend class Endpoint;
558558
};
559559

560-
/// A base class for objects that can be set up during @ref Communicator::setup().
561-
struct Setuppable {
562-
virtual ~Setuppable() = default;
563-
564-
/// Called inside @ref Communicator::setup() before any call to @ref endSetup() of any @ref Setuppable object that is
565-
/// being set up within the same @ref Communicator::setup() call.
566-
///
567-
/// @param bootstrap A shared pointer to the bootstrap implementation.
568-
virtual void beginSetup(std::shared_ptr<Bootstrap> bootstrap);
569-
570-
/// Called inside @ref Communicator::setup() after all calls to @ref beginSetup() of all @ref Setuppable objects that
571-
/// are being set up within the same @ref Communicator::setup() call.
572-
///
573-
/// @param bootstrap A shared pointer to the bootstrap implementation.
574-
virtual void endSetup(std::shared_ptr<Bootstrap> bootstrap);
575-
};
576-
577-
/// A non-blocking future that can be used to check if a value is ready and retrieve it.
578560
template <typename T>
579-
class NonblockingFuture {
580-
std::shared_future<T> future;
581-
582-
public:
583-
/// Default constructor.
584-
NonblockingFuture() = default;
585-
586-
/// Constructor that takes a shared future and moves it into the NonblockingFuture.
587-
///
588-
/// @param future The shared future to move.
589-
NonblockingFuture(std::shared_future<T>&& future) : future(std::move(future)) {}
590-
591-
/// Check if the value is ready to be retrieved.
592-
///
593-
/// @return True if the value is ready, false otherwise.
594-
bool ready() const { return future.wait_for(std::chrono::seconds(0)) == std::future_status::ready; }
595-
596-
/// Get the value.
597-
///
598-
/// @return The value.
599-
///
600-
/// @throws Error if the value is not ready.
601-
T get() const {
602-
if (!ready()) throw Error("NonblockingFuture::get() called before ready", ErrorCode::InvalidUsage);
603-
return future.get();
604-
}
605-
};
561+
using NonblockingFuture [[deprecated("Use std::shared_future instead. This will be removed in a future release.")]] =
562+
std::shared_future<T>;
606563

607564
/// A class that sets up all registered memories and connections between processes.
608565
///
609566
/// A typical way to use this class:
610-
/// 1. Call @ref connectOnSetup() to declare connections between the calling process with other processes.
611-
/// 2. Call @ref registerMemory() to register memory regions that will be used for communication.
612-
/// 3. Call @ref sendMemoryOnSetup() or @ref recvMemoryOnSetup() to send/receive registered memory regions to/from
567+
/// 1. Call connect() to declare connections between the calling process and other processes.
568+
/// 2. Call registerMemory() to register memory regions that will be used for communication.
569+
/// 3. Call sendMemory() or recvMemory() to send/receive registered memory regions to/from
613570
/// other processes.
614-
/// 4. Call @ref setup() to set up all registered memories and connections declared in the previous steps.
615-
/// 5. Call @ref NonblockingFuture<RegisteredMemory>::get() to get the registered memory regions received from other
616-
/// processes.
617-
/// 6. All done; use connections and registered memories to build channels.
571+
/// 4. Call get() on all futures returned by connect() and recvMemory().
572+
/// 5. All done; use connections and registered memories to build channels.
618573
///
619574
class Communicator {
620575
public:
@@ -645,40 +600,57 @@ class Communicator {
645600
/// @return RegisteredMemory A handle to the buffer.
646601
RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);
647602

648-
/// Send information of a registered memory to the remote side on setup.
603+
/// Send information of a registered memory to the remote side.
649604
///
650-
/// This function registers a send to a remote process that will happen by a following call of @ref setup(). The send
651-
/// will carry information about a registered memory on the local process.
605+
/// The send will be performed immediately upon calling this function.
652606
///
653607
/// @param memory The registered memory buffer to send information about.
654608
/// @param remoteRank The rank of the remote process.
655609
/// @param tag The tag to use for identifying the send.
656-
void sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag);
610+
void sendMemory(RegisteredMemory memory, int remoteRank, int tag);
657611

658-
/// Receive memory on setup.
612+
[[deprecated("Use sendMemory() instead. This will be removed in a future release.")]] void sendMemoryOnSetup(
613+
RegisteredMemory memory, int remoteRank, int tag) {
614+
sendMemory(memory, remoteRank, tag);
615+
}
616+
617+
/// Receive memory information from a corresponding sendMemory call on the remote side.
659618
///
660-
/// This function registers a receive from a remote process that will happen by a following call of @ref setup(). The
661-
/// receive will carry information about a registered memory on the remote process.
619+
/// This function returns a future immediately. The actual receive will be performed upon calling
620+
/// the first get() on the future.
662621
///
663622
/// @param remoteRank The rank of the remote process.
664623
/// @param tag The tag to use for identifying the receive.
665-
/// @return NonblockingFuture<RegisteredMemory> A non-blocking future of registered memory.
666-
NonblockingFuture<RegisteredMemory> recvMemoryOnSetup(int remoteRank, int tag);
624+
/// @return std::shared_future<RegisteredMemory> A future of registered memory; the receive is performed on the first get().
625+
std::shared_future<RegisteredMemory> recvMemory(int remoteRank, int tag);
667626

668-
/// Connect to a remote rank on setup.
627+
[[deprecated(
628+
"Use recvMemory() instead. This will be removed in a future release.")]] NonblockingFuture<RegisteredMemory>
629+
recvMemoryOnSetup(int remoteRank, int tag) {
630+
return recvMemory(remoteRank, tag);
631+
}
632+
633+
/// Connect to a remote rank.
669634
///
670-
/// This function only prepares metadata for connection. The actual connection is made by a following call of
671-
/// @ref setup(). Note that this function is two-way and a connection from rank `i` to remote rank `j` needs
635+
/// This function will immediately send metadata about the local endpoint to the remote rank, and return a future
636+
/// without waiting for the remote rank to respond. The connection will be established when the remote rank
637+
/// responds with its own endpoint and the local rank calls the first get() on the future.
638+
/// Note that this function is two-way and a connection from rank `i` to remote rank `j` needs
672639
/// to have a counterpart from rank `j` to rank `i`. Note that with IB, buffers are registered at a page level and if
673640
/// a buffer is spread through multiple pages and do not fully utilize all of them, IB's QP has to register for all
674641
/// involved pages. This potentially has security risks if the connection's accesses are given to a malicious process.
675642
///
676643
/// @param remoteRank The rank of the remote process.
677644
/// @param tag The tag of the connection for identifying it.
678645
/// @param config The configuration for the local endpoint.
679-
/// @return NonblockingFuture<NonblockingFuture<std::shared_ptr<Connection>>> A non-blocking future of shared pointer
680-
/// to the connection.
681-
NonblockingFuture<std::shared_ptr<Connection>> connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig);
646+
/// @return std::shared_future<std::shared_ptr<Connection>> A future of a shared pointer to the connection; the connection is established on the first get().
647+
std::shared_future<std::shared_ptr<Connection>> connect(int remoteRank, int tag, EndpointConfig localConfig);
648+
649+
[[deprecated("Use connect() instead. This will be removed in a future release.")]] NonblockingFuture<
650+
std::shared_ptr<Connection>>
651+
connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig) {
652+
return connect(remoteRank, tag, localConfig);
653+
}
682654

683655
/// Get the remote rank a connection is connected to.
684656
///
@@ -692,17 +664,7 @@ class Communicator {
692664
/// @return The tag the connection was made with.
693665
int tagOf(const Connection& connection);
694666

695-
/// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called.
696-
///
697-
/// @param setuppable A shared pointer to the Setuppable object.
698-
void onSetup(std::shared_ptr<Setuppable> setuppable);
699-
700-
/// Setup all objects that have registered for setup.
701-
///
702-
/// This includes previous calls of @ref sendMemoryOnSetup(), @ref recvMemoryOnSetup(), @ref connectOnSetup(), and
703-
/// @ref onSetup(). It is allowed to call this function multiple times, where the n-th call will only setup objects
704-
/// that have been registered after the (n-1)-th call.
705-
void setup();
667+
[[deprecated("setup() is now no-op and no longer needed. This will be removed in a future release.")]] void setup() {}
706668

707669
private:
708670
// The internal implementation.

include/mscclpp/semaphore.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ template <template <typename> typename InboundDeleter, template <typename> typen
3030
class BaseSemaphore {
3131
protected:
3232
/// The registered memory for the remote peer's inbound semaphore ID.
33-
NonblockingFuture<RegisteredMemory> remoteInboundSemaphoreIdsRegMem_;
33+
std::shared_future<RegisteredMemory> remoteInboundSemaphoreIdsRegMem_;
3434

3535
/// The inbound semaphore ID that is incremented by the remote peer and waited on by the local peer.
3636
///

python/mscclpp/core_py.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ extern void register_gpu_utils(nb::module_& m);
2929
template <typename T>
3030
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
3131
std::string pyclass_name = std::string("NonblockingFuture") + typestr;
32-
nb::class_<NonblockingFuture<T>>(m, pyclass_name.c_str())
33-
.def("ready", &NonblockingFuture<T>::ready)
34-
.def("get", &NonblockingFuture<T>::get);
32+
nb::class_<NonblockingFuture<T>>(m, pyclass_name.c_str()).def("get", &NonblockingFuture<T>::get);
3533
}
3634

3735
void register_core(nb::module_& m) {

src/communicator.cc

Lines changed: 17 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -30,79 +30,31 @@ MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t
3030
return context()->registerMemory(ptr, size, transports);
3131
}
3232

33-
struct MemorySender : public Setuppable {
34-
MemorySender(RegisteredMemory memory, int remoteRank, int tag)
35-
: memory_(memory), remoteRank_(remoteRank), tag_(tag) {}
36-
37-
void beginSetup(std::shared_ptr<Bootstrap> bootstrap) override {
38-
bootstrap->send(memory_.serialize(), remoteRank_, tag_);
39-
}
40-
41-
RegisteredMemory memory_;
42-
int remoteRank_;
43-
int tag_;
44-
};
45-
46-
MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag) {
47-
onSetup(std::make_shared<MemorySender>(memory, remoteRank, tag));
33+
MSCCLPP_API_CPP void Communicator::sendMemory(RegisteredMemory memory, int remoteRank, int tag) {
34+
pimpl_->bootstrap_->send(memory.serialize(), remoteRank, tag);
4835
}
4936

50-
struct MemoryReceiver : public Setuppable {
51-
MemoryReceiver(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {}
52-
53-
void endSetup(std::shared_ptr<Bootstrap> bootstrap) override {
37+
MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory(int remoteRank, int tag) {
38+
return std::async(std::launch::deferred, [this, remoteRank, tag]() {
5439
std::vector<char> data;
55-
bootstrap->recv(data, remoteRank_, tag_);
56-
memoryPromise_.set_value(RegisteredMemory::deserialize(data));
57-
}
58-
59-
std::promise<RegisteredMemory> memoryPromise_;
60-
int remoteRank_;
61-
int tag_;
62-
};
63-
64-
MSCCLPP_API_CPP NonblockingFuture<RegisteredMemory> Communicator::recvMemoryOnSetup(int remoteRank, int tag) {
65-
auto memoryReceiver = std::make_shared<MemoryReceiver>(remoteRank, tag);
66-
onSetup(memoryReceiver);
67-
return NonblockingFuture<RegisteredMemory>(memoryReceiver->memoryPromise_.get_future());
40+
bootstrap()->recv(data, remoteRank, tag);
41+
return RegisteredMemory::deserialize(data);
42+
});
6843
}
6944

70-
struct Communicator::Impl::Connector : public Setuppable {
71-
Connector(Communicator& comm, Communicator::Impl& commImpl_, int remoteRank, int tag, EndpointConfig localConfig)
72-
: comm_(comm),
73-
commImpl_(commImpl_),
74-
remoteRank_(remoteRank),
75-
tag_(tag),
76-
localEndpoint_(comm.context()->createEndpoint(localConfig)) {}
77-
78-
void beginSetup(std::shared_ptr<Bootstrap> bootstrap) override {
79-
bootstrap->send(localEndpoint_.serialize(), remoteRank_, tag_);
80-
}
45+
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
46+
EndpointConfig localConfig) {
47+
auto localEndpoint = pimpl_->context_->createEndpoint(localConfig);
48+
pimpl_->bootstrap_->send(localEndpoint.serialize(), remoteRank, tag);
8149

82-
void endSetup(std::shared_ptr<Bootstrap> bootstrap) override {
50+
return std::async(std::launch::deferred, [this, remoteRank, tag, localEndpoint = std::move(localEndpoint)]() mutable {
8351
std::vector<char> data;
84-
bootstrap->recv(data, remoteRank_, tag_);
52+
bootstrap()->recv(data, remoteRank, tag);
8553
auto remoteEndpoint = Endpoint::deserialize(data);
86-
auto connection = comm_.context()->connect(localEndpoint_, remoteEndpoint);
87-
commImpl_.connectionInfos_[connection.get()] = {remoteRank_, tag_};
88-
connectionPromise_.set_value(connection);
89-
INFO(MSCCLPP_INIT, "Connection %d -> %d created (%s)", comm_.bootstrap()->getRank(), remoteRank_,
90-
connection->getTransportName().c_str());
91-
}
92-
93-
std::promise<std::shared_ptr<Connection>> connectionPromise_;
94-
Communicator& comm_;
95-
Communicator::Impl& commImpl_;
96-
int remoteRank_;
97-
int tag_;
98-
Endpoint localEndpoint_;
99-
};
100-
101-
MSCCLPP_API_CPP NonblockingFuture<std::shared_ptr<Connection>> Communicator::connectOnSetup(
102-
int remoteRank, int tag, EndpointConfig localConfig) {
103-
auto connector = std::make_shared<Communicator::Impl::Connector>(*this, *pimpl_, remoteRank, tag, localConfig);
104-
onSetup(connector);
105-
return NonblockingFuture<std::shared_ptr<Connection>>(connector->connectionPromise_.get_future());
54+
auto connection = context()->connect(localEndpoint, remoteEndpoint);
55+
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
56+
return connection;
57+
});
10658
}
10759

10860
MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {
@@ -113,18 +65,4 @@ MSCCLPP_API_CPP int Communicator::tagOf(const Connection& connection) {
11365
return pimpl_->connectionInfos_.at(&connection).tag;
11466
}
11567

116-
MSCCLPP_API_CPP void Communicator::onSetup(std::shared_ptr<Setuppable> setuppable) {
117-
pimpl_->toSetup_.push_back(setuppable);
118-
}
119-
120-
MSCCLPP_API_CPP void Communicator::setup() {
121-
for (auto& setuppable : pimpl_->toSetup_) {
122-
setuppable->beginSetup(pimpl_->bootstrap_);
123-
}
124-
for (auto& setuppable : pimpl_->toSetup_) {
125-
setuppable->endSetup(pimpl_->bootstrap_);
126-
}
127-
pimpl_->toSetup_.clear();
128-
}
129-
13068
} // namespace mscclpp

src/core.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,6 @@ const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transpo
8989

9090
const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc | Transport::Ethernet;
9191

92-
void Setuppable::beginSetup(std::shared_ptr<Bootstrap>) {}
93-
94-
void Setuppable::endSetup(std::shared_ptr<Bootstrap>) {}
95-
9692
} // namespace mscclpp
9793

9894
namespace std {

src/include/communicator.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ struct Communicator::Impl {
2222
std::shared_ptr<Bootstrap> bootstrap_;
2323
std::shared_ptr<Context> context_;
2424
std::unordered_map<const Connection*, ConnectionInfo> connectionInfos_;
25-
std::vector<std::shared_ptr<Setuppable>> toSetup_;
2625

2726
Impl(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<Context> context);
2827

test/unit/core_tests.cc

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,6 @@ class LocalCommunicatorTest : public ::testing::Test {
1818
std::shared_ptr<mscclpp::Communicator> comm;
1919
};
2020

21-
class MockSetuppable : public mscclpp::Setuppable {
22-
public:
23-
MOCK_METHOD(void, beginSetup, (std::shared_ptr<mscclpp::Bootstrap> bootstrap), (override));
24-
MOCK_METHOD(void, endSetup, (std::shared_ptr<mscclpp::Bootstrap> bootstrap), (override));
25-
};
26-
27-
TEST_F(LocalCommunicatorTest, OnSetup) {
28-
auto mockSetuppable = std::make_shared<MockSetuppable>();
29-
comm->onSetup(mockSetuppable);
30-
EXPECT_CALL(*mockSetuppable, beginSetup(std::dynamic_pointer_cast<mscclpp::Bootstrap>(bootstrap)));
31-
EXPECT_CALL(*mockSetuppable, endSetup(std::dynamic_pointer_cast<mscclpp::Bootstrap>(bootstrap)));
32-
comm->setup();
33-
}
34-
3521
TEST_F(LocalCommunicatorTest, RegisterMemory) {
3622
int dummy[42];
3723
auto memory = comm->registerMemory(&dummy, sizeof(dummy), mscclpp::NoTransports);

0 commit comments

Comments
 (0)