@@ -127,8 +127,8 @@ class TcpBootstrap : public Bootstrap {
127
127
// / @return The unique ID stored in the TcpBootstrap.
128
128
UniqueId getUniqueId () const ;
129
129
130
- // / Initialize the TcpBootstrap with a given unique ID. The unique ID can be generated by any methods ;
131
- // / it can be created by createUniqueId() or can be any arbitrary bit arrays provided by the user.
130
+ // / Initialize the TcpBootstrap with a given unique ID. The unique ID can be generated by any method ;
131
+ // / it can be created by createUniqueId() or can be any arbitrary bit array provided by the user.
132
132
// / @param uniqueId The unique ID to initialize the TcpBootstrap with.
133
133
// / @param timeoutSec The connection timeout in seconds.
134
134
void initialize (UniqueId uniqueId, int64_t timeoutSec = 30 );
@@ -453,7 +453,7 @@ class Endpoint {
453
453
// / @return A vector of characters representing the serialized Endpoint object.
454
454
std::vector<char > serialize ();
455
455
456
- // / Deserialize a Endpoint object from a vector of characters.
456
+ // / Deserialize an Endpoint object from a vector of characters.
457
457
// /
458
458
// / @param data A vector of characters representing a serialized Endpoint object.
459
459
// / @return A deserialized Endpoint object.
@@ -473,8 +473,10 @@ class Connection {
473
473
public:
474
474
// / Constructor.
475
475
// / @param maxWriteQueueSize The maximum number of write requests that can be queued.
476
- Connection (int maxWriteQueueSize) : maxWriteQueueSize(maxWriteQueueSize){};
476
+ Connection (std::shared_ptr<Context> context, int maxWriteQueueSize)
477
+ : context_(context), maxWriteQueueSize_(maxWriteQueueSize){};
477
478
479
+ // / Destructor.
478
480
virtual ~Connection () = default ;
479
481
480
482
// / Write data from a source RegisteredMemory to a destination RegisteredMemory.
@@ -487,7 +489,7 @@ class Connection {
487
489
virtual void write (RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
488
490
uint64_t size) = 0;
489
491
490
- // / Update a 8-byte value in a destination RegisteredMemory and synchronize the change with the remote process.
492
+ // / Update an 8-byte value in a destination RegisteredMemory and synchronize the change with the remote process.
491
493
// /
492
494
// / @param dst The destination RegisteredMemory.
493
495
// / @param dstOffset The offset in bytes from the start of the destination RegisteredMemory.
@@ -522,7 +524,9 @@ class Connection {
522
524
// Internal methods for getting implementation pointers.
523
525
static std::shared_ptr<RegisteredMemory::Impl> getImpl (RegisteredMemory& memory);
524
526
static std::shared_ptr<Endpoint::Impl> getImpl (Endpoint& memory);
525
- int maxWriteQueueSize;
527
+
528
+ std::shared_ptr<Context> context_;
529
+ int maxWriteQueueSize_;
526
530
};
527
531
528
532
// / Used to configure an endpoint.
@@ -567,19 +571,19 @@ struct EndpointConfig {
567
571
// / 1. The client creates an endpoint with createEndpoint() and sends it to the server.
568
572
// / 2. The server receives the client endpoint, creates its own endpoint with createEndpoint(), sends it to the
569
573
// / client, and creates a connection with connect().
570
- // / 4 . The client receives the server endpoint, creates a connection with connect() and sends a
574
+ // / 3 . The client receives the server endpoint, creates a connection with connect() and sends a
571
575
// / RegisteredMemory to the server.
572
- // / 5 . The server receives the RegisteredMemory and writes to it using the previously created connection.
573
- // / The client waiting to create a connection before sending the RegisteredMemory ensures that the server can not
576
+ // / 4 . The server receives the RegisteredMemory and writes to it using the previously created connection.
577
+ // / The client waiting to create a connection before sending the RegisteredMemory ensures that the server cannot
574
578
// / write to the RegisteredMemory before the connection is established.
575
579
// /
576
580
// / While some transports may have more relaxed implementation behavior, this should not be relied upon.
577
- class Context {
581
+ class Context : public std ::enable_shared_from_this<Context> {
578
582
public:
579
- // / Create a context .
580
- Context ();
583
+ // / Create a new Context instance .
584
+ static std::shared_ptr< Context> create () { return std::shared_ptr<Context>( new Context ()); }
581
585
582
- // / Destroy the context .
586
+ // / Destructor .
583
587
~Context ();
584
588
585
589
// / Register a region of GPU memory for use in this context.
@@ -606,6 +610,8 @@ class Context {
606
610
std::shared_ptr<Connection> connect (Endpoint localEndpoint, Endpoint remoteEndpoint);
607
611
608
612
private:
613
+ Context ();
614
+
609
615
struct Impl ;
610
616
std::unique_ptr<Impl> pimpl_;
611
617
@@ -620,7 +626,7 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
620
626
// / A class that sets up all registered memories and connections between processes.
621
627
// /
622
628
// / A typical way to use this class:
623
- // / 1. Call connect() to declare connections between the calling process with other processes.
629
+ // / 1. Call connect() to declare connections between the calling process and other processes.
624
630
// / 2. Call registerMemory() to register memory regions that will be used for communication.
625
631
// / 3. Call sendMemory() or recvMemory() to send/receive registered memory regions to/from
626
632
// / other processes.
@@ -670,7 +676,7 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
670
676
// / auto connection = communicator.connect(0, tag, Transport::CudaIpc); // undefined behavior
671
677
// / communicator.sendMemory(memory1, 0, tag);
672
678
// / ```
673
- // / In the wrong example, the connection information from rank 1 will be sent to `mem1` object on rank 0,
679
+ // / In the wrong example, the connection information from rank 1 will be sent to the `mem1` object on rank 0,
674
680
// / where the object type is RegisteredMemory, not Connection.
675
681
// /
676
682
class Communicator {
@@ -762,7 +768,7 @@ class Communicator {
762
768
// / the first get() on the future.
763
769
// / Note that this function is two-way and a connection from rank `i` to remote rank `j` needs
764
770
// / to have a counterpart from rank `j` to rank `i`. Note that with IB, buffers are registered at a page level and if
765
- // / a buffer is spread through multiple pages and do not fully utilize all of them, IB's QP has to register for all
771
+ // / a buffer is spread through multiple pages and does not fully utilize all of them, IB's QP has to register for all
766
772
// / involved pages. This potentially has security risks if the connection's accesses are given to a malicious process.
767
773
// /
768
774
// / Multiple calls to either sendMemory() or connect() with the same @p remoteRank and @p tag will be ordered by
@@ -818,11 +824,11 @@ extern const TransportFlags AllIBTransports;
818
824
// / A constant TransportFlags object representing all transports.
819
825
extern const TransportFlags AllTransports;
820
826
821
- // / A type which could be safely used in device side.
827
+ // / A type which could be safely used on the device side.
822
828
template <class T >
823
829
using DeviceHandle = typename T::DeviceHandle;
824
830
825
- // / Retrieve the deviceHandle instance from host object.
831
+ // / Retrieve the deviceHandle instance from a host object.
826
832
template <typename T>
827
833
DeviceHandle<std::remove_reference_t <T>> deviceHandle (T&& t) {
828
834
return t.deviceHandle ();
0 commit comments