Commit b4dde38

FIFO improvements (#557)
* Revert `MSCCLPP_FIFO_USE_TAIL_REPLICA=1` back to the default.
* Optimize `FifoDeviceHandle`.
* Stop using `cudaHostAllocWriteCombined`, which increases latency.
* Pin host memory for `Host2DeviceSemaphore::outboundSemaphore_`.
* Fix proxy NUMA binding issues.
* Prevent graph capture inside proxy threads.
* `CudaIpcConnection` now skips the stream sync when it is unnecessary.
* Every connection type now holds a shared pointer to its context for memory safety.
* A context is now always managed by a shared pointer for memory safety.
* Minor docs & interface improvements.
* Minor fix in the `mscclpp-test` correctness test.
1 parent 2796cfa commit b4dde38

28 files changed: +386 −355 lines

CMakeLists.txt

Lines changed: 8 additions & 22 deletions
@@ -63,36 +63,22 @@ else()
 endif()

 if(MSCCLPP_GPU_ARCHS)
-  # Remove any leading/trailing whitespace
   string(STRIP "${MSCCLPP_GPU_ARCHS}" MSCCLPP_GPU_ARCHS)
-
-  # Split the string into a list
   string(REPLACE " " ";" MSCCLPP_GPU_ARCHS "${MSCCLPP_GPU_ARCHS}")
   string(REPLACE "," ";" MSCCLPP_GPU_ARCHS "${MSCCLPP_GPU_ARCHS}")
-
-  # Check if the list is empty
   if(NOT MSCCLPP_GPU_ARCHS)
-    message(FATAL_ERROR "MSCCLPP_GPU_ARCHS is given empty. Please specify GPU architectures or do not set MSCCLPP_GPU_ARCHS.")
+    message(FATAL_ERROR "MSCCLPP_GPU_ARCHS is empty. Specify GPU architectures or leave unset.")
   endif()
 elseif(MSCCLPP_USE_CUDA)
-  # CUDA 11 or higher is required
-  if(CUDAToolkit_VERSION_MAJOR LESS 11)
-    message(FATAL_ERROR "CUDA 11 or higher is required but detected ${CUDAToolkit_VERSION}")
-  endif()
-
-  # Ampere architecture
-  if(CUDAToolkit_VERSION_MAJOR GREATER_EQUAL 11)
-    set(MSCCLPP_GPU_ARCHS 80)
+  if(CUDAToolkit_VERSION VERSION_LESS "11.8")
+    message(FATAL_ERROR "CUDA 11.8 or higher required, found ${CUDAToolkit_VERSION}")
   endif()
-
-
-  # Hopper architecture
-  if(CUDAToolkit_VERSION_MAJOR GREATER_EQUAL 12)
-    set(MSCCLPP_GPU_ARCHS ${MSCCLPP_GPU_ARCHS} 90)
+  set(MSCCLPP_GPU_ARCHS 80)
+  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.0")
+    list(APPEND MSCCLPP_GPU_ARCHS 90)
   endif()
-
-  # Blackwell architecture
-  if(CUDAToolkit_VERSION_MAJOR GREATER_EQUAL 12 AND CUDAToolkit_VERSION_MINOR GREATER_EQUAL 8)
-    set(MSCCLPP_GPU_ARCHS ${MSCCLPP_GPU_ARCHS} 100)
+  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
+    list(APPEND MSCCLPP_GPU_ARCHS 100)
   endif()
 elseif(MSCCLPP_USE_ROCM)
   set(CMAKE_HIP_ARCHITECTURES gfx90a gfx941 gfx942)
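For reference, `MSCCLPP_GPU_ARCHS` still overrides this auto-detection when set explicitly, e.g. configuring with `-DMSCCLPP_GPU_ARCHS=80;90` to build only for Ampere and Hopper (illustrative values). When it is left unset, the CUDA toolkit version alone now decides whether 80, 90, and 100 are included.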

include/mscclpp/core.hpp

Lines changed: 24 additions & 18 deletions
@@ -127,8 +127,8 @@ class TcpBootstrap : public Bootstrap {
   /// @return The unique ID stored in the TcpBootstrap.
   UniqueId getUniqueId() const;

-  /// Initialize the TcpBootstrap with a given unique ID. The unique ID can be generated by any methods;
-  /// it can be created by createUniqueId() or can be any arbitrary bit arrays provided by the user.
+  /// Initialize the TcpBootstrap with a given unique ID. The unique ID can be generated by any method;
+  /// it can be created by createUniqueId() or can be any arbitrary bit array provided by the user.
   /// @param uniqueId The unique ID to initialize the TcpBootstrap with.
   /// @param timeoutSec The connection timeout in seconds.
   void initialize(UniqueId uniqueId, int64_t timeoutSec = 30);
@@ -453,7 +453,7 @@ class Endpoint {
   /// @return A vector of characters representing the serialized Endpoint object.
   std::vector<char> serialize();

-  /// Deserialize a Endpoint object from a vector of characters.
+  /// Deserialize an Endpoint object from a vector of characters.
   ///
   /// @param data A vector of characters representing a serialized Endpoint object.
   /// @return A deserialized Endpoint object.
@@ -473,8 +473,10 @@ class Connection {
  public:
   /// Constructor.
   /// @param maxWriteQueueSize The maximum number of write requests that can be queued.
-  Connection(int maxWriteQueueSize) : maxWriteQueueSize(maxWriteQueueSize){};
+  Connection(std::shared_ptr<Context> context, int maxWriteQueueSize)
+      : context_(context), maxWriteQueueSize_(maxWriteQueueSize){};

+  /// Destructor.
   virtual ~Connection() = default;

   /// Write data from a source RegisteredMemory to a destination RegisteredMemory.
@@ -487,7 +489,7 @@ class Connection {
   virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
                      uint64_t size) = 0;

-  /// Update a 8-byte value in a destination RegisteredMemory and synchronize the change with the remote process.
+  /// Update an 8-byte value in a destination RegisteredMemory and synchronize the change with the remote process.
   ///
   /// @param dst The destination RegisteredMemory.
   /// @param dstOffset The offset in bytes from the start of the destination RegisteredMemory.
@@ -522,7 +524,9 @@ class Connection {
   // Internal methods for getting implementation pointers.
   static std::shared_ptr<RegisteredMemory::Impl> getImpl(RegisteredMemory& memory);
   static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
-  int maxWriteQueueSize;
+
+  std::shared_ptr<Context> context_;
+  int maxWriteQueueSize_;
 };

 /// Used to configure an endpoint.
@@ -567,19 +571,19 @@ struct EndpointConfig {
 /// 1. The client creates an endpoint with createEndpoint() and sends it to the server.
 /// 2. The server receives the client endpoint, creates its own endpoint with createEndpoint(), sends it to the
 ///    client, and creates a connection with connect().
-/// 4. The client receives the server endpoint, creates a connection with connect() and sends a
+/// 3. The client receives the server endpoint, creates a connection with connect() and sends a
 ///    RegisteredMemory to the server.
-/// 5. The server receives the RegisteredMemory and writes to it using the previously created connection.
-/// The client waiting to create a connection before sending the RegisteredMemory ensures that the server can not
+/// 4. The server receives the RegisteredMemory and writes to it using the previously created connection.
+/// The client waiting to create a connection before sending the RegisteredMemory ensures that the server cannot
 /// write to the RegisteredMemory before the connection is established.
 ///
 /// While some transports may have more relaxed implementation behavior, this should not be relied upon.
-class Context {
+class Context : public std::enable_shared_from_this<Context> {
  public:
-  /// Create a context.
-  Context();
+  /// Create a new Context instance.
+  static std::shared_ptr<Context> create() { return std::shared_ptr<Context>(new Context()); }

-  /// Destroy the context.
+  /// Destructor.
   ~Context();

   /// Register a region of GPU memory for use in this context.
@@ -606,6 +610,8 @@ class Context {
   std::shared_ptr<Connection> connect(Endpoint localEndpoint, Endpoint remoteEndpoint);

  private:
+  Context();
+
   struct Impl;
   std::unique_ptr<Impl> pimpl_;

@@ -620,7 +626,7 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will
 /// A class that sets up all registered memories and connections between processes.
 ///
 /// A typical way to use this class:
-/// 1. Call connect() to declare connections between the calling process with other processes.
+/// 1. Call connect() to declare connections between the calling process and other processes.
 /// 2. Call registerMemory() to register memory regions that will be used for communication.
 /// 3. Call sendMemory() or recvMemory() to send/receive registered memory regions to/from
 ///    other processes.
@@ -670,7 +676,7 @@
 /// auto connection = communicator.connect(0, tag, Transport::CudaIpc); // undefined behavior
 /// communicator.sendMemory(memory1, 0, tag);
 /// ```
-/// In the wrong example, the connection information from rank 1 will be sent to `mem1` object on rank 0,
+/// In the wrong example, the connection information from rank 1 will be sent to the `mem1` object on rank 0,
 /// where the object type is RegisteredMemory, not Connection.
 ///
 class Communicator {
@@ -762,7 +768,7 @@ class Communicator {
 /// the first get() on the future.
 /// Note that this function is two-way and a connection from rank `i` to remote rank `j` needs
 /// to have a counterpart from rank `j` to rank `i`. Note that with IB, buffers are registered at a page level and if
-/// a buffer is spread through multiple pages and do not fully utilize all of them, IB's QP has to register for all
+/// a buffer is spread through multiple pages and does not fully utilize all of them, IB's QP has to register for all
 /// involved pages. This potentially has security risks if the connection's accesses are given to a malicious process.
 ///
 /// Multiple calls to either sendMemory() or connect() with the same @p remoteRank and @p tag will be ordered by
@@ -818,11 +824,11 @@ extern const TransportFlags AllIBTransports;
 /// A constant TransportFlags object representing all transports.
 extern const TransportFlags AllTransports;

-/// A type which could be safely used in device side.
+/// A type which could be safely used on the device side.
 template <class T>
 using DeviceHandle = typename T::DeviceHandle;

-/// Retrieve the deviceHandle instance from host object.
+/// Retrieve the deviceHandle instance from a host object.
 template <typename T>
 DeviceHandle<std::remove_reference_t<T>> deviceHandle(T&& t) {
   return t.deviceHandle();
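To illustrate the new ownership model, a minimal sketch (not part of the commit): `Context` can no longer be constructed directly; it is obtained through the static `create()` factory and held by a `std::shared_ptr`, and every `Connection` keeps a shared pointer back to the `Context` that created it (`Connection::context_`), so the context cannot be destroyed while connections are still alive.

#include <memory>

#include <mscclpp/core.hpp>

// Sketch only: the constructor is now private, so the factory is the only way to obtain a Context.
std::shared_ptr<mscclpp::Context> makeContext() {
  // Before this change:  mscclpp::Context context;   // direct, stack-managed construction
  // After this change:   shared_ptr-managed via the static factory.
  auto context = mscclpp::Context::create();
  return context;  // connections created via context->connect(...) keep this context alive
}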

include/mscclpp/env.hpp

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ class Env {
   /// Env name: `MSCCLPP_FIFO_USE_TAIL_REPLICA`. If set to true, it will replicate the FIFO tail on the GPU memory,
   /// which makes the GPU poll on the tail faster, but requires a periodic FIFO flush to update the replica on the GPU.
   /// If set to false, the GPU will directly read the tail from the host memory, which is slower but does not require
-  /// periodic flushes. Default is false.
+  /// periodic flushes. Default is true.
   const bool fifoUseTailReplica;

 private:
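A small host-side sketch of checking this flag at runtime (illustrative, not from the commit; it assumes the `mscclpp::env()` accessor declared in this header, so adapt if the accessor differs). The flag is read from the `MSCCLPP_FIFO_USE_TAIL_REPLICA` environment variable at startup:

#include <cstdio>

#include <mscclpp/env.hpp>

// Sketch: report whether the GPU polls a device-side tail replica (true, needs periodic
// flushTail()) or reads the tail directly from host memory (false, no flush needed).
void reportFifoTailMode() {
  auto e = mscclpp::env();  // assumed accessor returning the process-wide Env instance
  std::printf("MSCCLPP_FIFO_USE_TAIL_REPLICA = %s\n", e->fifoUseTailReplica ? "true" : "false");
}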

include/mscclpp/fifo.hpp

Lines changed: 13 additions & 18 deletions
@@ -4,51 +4,46 @@
 #ifndef MSCCLPP_FIFO_HPP_
 #define MSCCLPP_FIFO_HPP_

-#include <cstdint>
-#include <functional>
 #include <memory>

 #include "fifo_device.hpp"

 namespace mscclpp {

-constexpr size_t DEFAULT_FIFO_SIZE = 128;
+constexpr size_t DEFAULT_FIFO_SIZE = 512;

-/// A class representing a host proxy FIFO that can consume work elements pushed by device threads.
+/// Host-side proxy FIFO for device-produced work elements.
 class Fifo {
  public:
-  /// Constructs a new Fifo object.
-  /// @param size The number of entires in the FIFO.
+  /// Constructor.
+  /// @param size Number of entries (default: DEFAULT_FIFO_SIZE).
   Fifo(int size = DEFAULT_FIFO_SIZE);

-  /// Destroys the Fifo object.
+  /// Destructor.
   ~Fifo();

-  /// Polls the FIFO for a trigger.
-  ///
-  /// Returns ProxyTrigger which is the trigger at the head of fifo.
+  /// Poll and get the trigger at the head.
+  /// @return ProxyTrigger at the head of the FIFO.
   ProxyTrigger poll();

-  /// Pops a trigger from the FIFO.
+  /// Remove the head trigger.
   void pop();

   /// Flushes the tail of the FIFO.
-  ///
   /// @param sync If true, waits for the flush to complete before returning.
   void flushTail(bool sync = false);

-  /// Return the FIFO size.
-  /// @return The FIFO size.
+  /// Get FIFO size.
+  /// @return Number of entries in the FIFO.
   int size() const;

-  /// Returns a FifoDeviceHandle object representing the device FIFO.
-  ///
-  /// @return A FifoDeviceHandle object representing the device FIFO.
+  /// Get device-side FIFO handle.
+  /// @return FifoDeviceHandle for device access.
   FifoDeviceHandle deviceHandle() const;

  private:
   struct Impl;
-  std::unique_ptr<Impl> pimpl;
+  std::unique_ptr<Impl> pimpl_;
 };

 }  // namespace mscclpp
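For context, a host-side consumer loop might look like the sketch below (not from the commit; the library's own Proxy class implements this, and `handleTrigger` is a hypothetical callback). An empty head slot still has `fst == 0`, and the top bit of `snd` that the device flips on push is reverted before the trigger is interpreted:

#include <cstdint>

#include <mscclpp/fifo.hpp>

// Sketch of a proxy consumer loop over the Fifo API above.
void proxyLoop(mscclpp::Fifo& fifo, const bool& running,
               void (*handleTrigger)(mscclpp::ProxyTrigger)) {  // hypothetical callback
  uint64_t handled = 0;
  while (running) {
    mscclpp::ProxyTrigger trigger = fifo.poll();
    if (trigger.fst == 0 || trigger.snd == 0) continue;    // head slot not fully written yet
    trigger.snd ^= (uint64_t{1} << 63);                     // revert the bit the device flipped on push
    handleTrigger(trigger);
    fifo.pop();                                             // free the slot for device producers
    if ((++handled % fifo.size()) == 0) fifo.flushTail();   // keep the device-side tail replica fresh
  }
  fifo.flushTail(/*sync=*/true);                            // drain pending tail updates before shutdown
}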

include/mscclpp/fifo_device.hpp

Lines changed: 31 additions & 32 deletions
@@ -15,7 +15,11 @@

 namespace mscclpp {

-/// A struct representing a pair of 64-bit unsigned integers used as a trigger for the proxy.
+#if defined(MSCCLPP_DEVICE_COMPILE)
+MSCCLPP_DEVICE_INLINE uint64_t hostLoadRelaxed(uint64_t* ptr) { return atomicLoad(ptr, memoryOrderRelaxed); }
+#endif  // defined(MSCCLPP_DEVICE_COMPILE)
+
+/// Pair of 64-bit unsigned integers used as a trigger for the proxy.
 ///
 /// This struct is used as a work element in the concurrent FIFO where multiple device threads can push
 /// ProxyTrigger elements and a single host proxy thread consumes these work elements.
@@ -45,68 +49,63 @@ struct alignas(16) ProxyTrigger {
 struct FifoDeviceHandle {
 #if defined(MSCCLPP_DEVICE_COMPILE)
   /// Push a trigger to the FIFO.
-  ///
-  /// @param trigger The trigger to push.
-  /// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
-  /// @return The new head of the FIFO.
+  /// @param trigger Trigger to push.
+  /// @param maxSpinCount Max spin count before assert. Never assert if negative.
+  /// @return Previous head of the FIFO where the trigger was pushed.
   MSCCLPP_DEVICE_INLINE uint64_t push(ProxyTrigger trigger, [[maybe_unused]] int64_t maxSpinCount = 1000000) {
-    uint64_t curFifoHead = atomicFetchAdd(this->head, (uint64_t)1, memoryOrderRelaxed);
+    uint64_t prevHead = atomicFetchAdd<uint64_t, scopeDevice>(head, 1, memoryOrderRelaxed);

-    // make the last bit intentionally non-zero so that we can safely poll. Don't worry, we will change it back in host
-    // side
-    trigger.snd ^= ((uint64_t)1 << (uint64_t)63);
+    // Flip the last bit for safe polling; host will revert.
+    constexpr uint64_t flipMask = uint64_t{1} << uint64_t{63};
+    trigger.snd ^= flipMask;

     // Only one of two conditions need to be met to proceed. Either the tail has advanced enough or where we need to
     // write to is 0. However, the first condition is faster to check since the tail is flushed periodically anyways but
     // for the second condition we need to read CPU memory.
     // As atomic access is slow, we first check using the bare pointer and then use the atomic load if the
     // condition is not met.
-    if (curFifoHead >= size + *(this->tailReplica)) {
-      OR_POLL_MAYBE_JAILBREAK((curFifoHead >= size + atomicLoad(this->tailReplica, memoryOrderRelaxed)),
-                              (atomicLoad(&(this->triggers[curFifoHead % size].fst), memoryOrderRelaxed) != 0),
-                              maxSpinCount);
+    if (prevHead >= size + *tailReplica) {
+      OR_POLL_MAYBE_JAILBREAK((prevHead >= size + atomicLoad(tailReplica, memoryOrderRelaxed)),
+                              (hostLoadRelaxed(&(triggers[prevHead % size].fst)) != 0), maxSpinCount);
     }

-    ProxyTrigger* triggerPtr = &(this->triggers[curFifoHead % size]);
+    ProxyTrigger* triggerPtr = &(triggers[prevHead % size]);

-    // Make sure the data is visible to the host before we update the tail.
 #if defined(MSCCLPP_DEVICE_CUDA)
 #if __CUDA_ARCH__ == 800
-    // For A100, threadfence_system is more efficient than release
+    // This is faster than release for A100.
     __threadfence_system();
     asm volatile("st.global.relaxed.sys.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
 #else
     asm volatile("st.global.release.sys.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
 #endif
 #else  // !defined(MSCCLPP_DEVICE_CUDA)
-    // store snd no later than fst.
+    // Store snd no later than fst.
     atomicStore(&(triggerPtr->snd), trigger.snd, memoryOrderRelaxed);
     atomicStore(&(triggerPtr->fst), trigger.fst, memoryOrderRelease);
 #endif  // !defined(MSCCLPP_DEVICE_CUDA)

-    return curFifoHead;
+    return prevHead;
   }

-  /// Wait until there is a place in the FIFO to push a trigger.
-  ///
-  /// @param curFifoHead The current head of the FIFO.
-  /// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
-  MSCCLPP_DEVICE_INLINE void sync(uint64_t curFifoHead, [[maybe_unused]] int64_t maxSpinCount = 1000000) {
-    // Same as push but in this case checking the fist condition is probably faster since for tail to be pushed we need
+  /// Wait until a specific trigger is popped from the FIFO.
+  /// @param fifoHead FIFO head where the trigger was pushed.
+  /// @param maxSpinCount Max spin count before assert. Never assert if negative.
+  MSCCLPP_DEVICE_INLINE void sync(uint64_t fifoHead, [[maybe_unused]] int64_t maxSpinCount = 1000000) {
+    // Same as push but in this case checking the first condition is probably faster since for tail to be pushed we need
     // to wait for cudaMemcpy to be done.
-    OR_POLL_MAYBE_JAILBREAK((curFifoHead >= atomicLoad(this->tailReplica, memoryOrderRelaxed)),
-                            (atomicLoad(&(this->triggers[curFifoHead % size].fst), memoryOrderRelaxed) != 0),
-                            maxSpinCount);
+    OR_POLL_MAYBE_JAILBREAK((fifoHead >= atomicLoad(tailReplica, memoryOrderRelaxed)),
+                            (hostLoadRelaxed(&(triggers[fifoHead % size].fst)) != 0), maxSpinCount);
   }
 #endif  // defined(MSCCLPP_DEVICE_COMPILE)

-  /// The FIFO buffer that is allocated on the host via `cudaHostAlloc()`.
+  /// FIFO buffer on host.
   ProxyTrigger* triggers;
-  /// Replica of the FIFO tail.
-  uint64_t* tailReplica;
-  /// The FIFO head. Allocated on the device and only accessed by the device.
+  /// FIFO head on device.
   uint64_t* head;
-  /// The FIFO size.
+  /// FIFO tail replica on device.
+  uint64_t* tailReplica;
+  /// FIFO size.
   int size;
 };
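On the device side, a producer uses the handle roughly as sketched below (illustrative kernel, not part of the commit; the payload values are placeholders). `push()` returns the slot the trigger was written to, and `sync()` spins until the host proxy has popped that slot:

#include <mscclpp/fifo_device.hpp>

// Sketch: a single thread posts one trigger to the host proxy and waits for completion.
__global__ void postTriggerKernel(mscclpp::FifoDeviceHandle fifo) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    mscclpp::ProxyTrigger trigger;
    trigger.fst = 1;  // placeholder payload; real triggers encode the requested operation
    trigger.snd = 0;  // push() flips the top bit of snd, so the stored value is never zero
    uint64_t slot = fifo.push(trigger);
    fifo.sync(slot);  // wait until the host proxy has consumed this trigger
  }
}
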
include/mscclpp/gpu_utils.hpp

Lines changed: 5 additions & 5 deletions
@@ -123,7 +123,7 @@ namespace detail {
 void setReadWriteMemoryAccess(void* base, size_t size);

 void* gpuCalloc(size_t bytes);
-void* gpuCallocHost(size_t bytes);
+void* gpuCallocHost(size_t bytes, unsigned int flags);
 #if defined(__HIP_PLATFORM_AMD__)
 void* gpuCallocUncached(size_t bytes);
 #endif  // defined(__HIP_PLATFORM_AMD__)
@@ -206,13 +206,13 @@
 }

 template <class T>
-auto gpuCallocHostShared(size_t nelems = 1) {
-  return detail::safeAlloc<T, detail::GpuHostDeleter<T>, std::shared_ptr<T>>(detail::gpuCallocHost, nelems);
+auto gpuCallocHostShared(size_t nelems = 1, unsigned int flags = cudaHostAllocMapped) {
+  return detail::safeAlloc<T, detail::GpuHostDeleter<T>, std::shared_ptr<T>>(detail::gpuCallocHost, nelems, flags);
 }

 template <class T>
-auto gpuCallocHostUnique(size_t nelems = 1) {
-  return detail::safeAlloc<T, detail::GpuHostDeleter<T>, UniqueGpuHostPtr<T>>(detail::gpuCallocHost, nelems);
+auto gpuCallocHostUnique(size_t nelems = 1, unsigned int flags = cudaHostAllocMapped) {
+  return detail::safeAlloc<T, detail::GpuHostDeleter<T>, UniqueGpuHostPtr<T>>(detail::gpuCallocHost, nelems, flags);
 }

 #if defined(__HIP_PLATFORM_AMD__)
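A short usage sketch for the new `flags` parameter (illustrative, not from the commit). The default stays `cudaHostAllocMapped`; `cudaHostAllocWriteCombined` is no longer used internally because it slows down host reads, but callers can still pass any `cudaHostAlloc` flag explicitly:

#include <cstdint>
#include <memory>

#include <mscclpp/gpu_utils.hpp>

// Sketch: zero-initialized, pinned (page-locked) host allocation with an explicit flag.
std::shared_ptr<uint64_t> allocMappedHostCounter() {
  // Mapped pinned memory (the default flag), e.g. for a host-visible counter or semaphore value.
  return mscclpp::gpuCallocHostShared<uint64_t>(1, cudaHostAllocMapped);
}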

include/mscclpp/memory_channel_device.hpp

Lines changed: 0 additions & 6 deletions
@@ -35,12 +35,6 @@ struct BaseMemoryChannelDeviceHandle {
   ///
   MSCCLPP_DEVICE_INLINE void relaxedSignal() { semaphore_.relaxedSignal(); }

-  /// Increase the counter of the local semaphore.
-  MSCCLPP_DEVICE_INLINE void semaphoreIncrement() { semaphore_.semaphoreIncrement(); }
-
-  /// Read the counter of the local semaphore.
-  MSCCLPP_DEVICE_INLINE uint64_t semaphoreGetLocal() const { return semaphore_.semaphoreGetLocal(); }
-
   /// Check if the remote semaphore has signaled.
   /// @return true if the remote semaphore has signaled.
   MSCCLPP_DEVICE_INLINE bool poll() { return semaphore_.poll(); }
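With the local-counter helpers removed, device code coordinates through the remaining signal/poll interface shown above. A minimal sketch (illustrative kernel, not part of the commit):

#include <mscclpp/memory_channel_device.hpp>

// Sketch: one side signals, the other spins on poll() until the remote signal arrives.
__global__ void pingKernel(mscclpp::BaseMemoryChannelDeviceHandle channel, bool isSender) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    if (isSender) {
      channel.relaxedSignal();   // notify the remote peer
    } else {
      while (!channel.poll()) {  // returns true once the remote semaphore has signaled
      }
    }
  }
}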
