Fix memory leak bug #534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 11 commits into main
15 changes: 13 additions & 2 deletions include/mscclpp/port_channel.hpp
@@ -4,6 +4,8 @@
#ifndef MSCCLPP_PORT_CHANNEL_HPP_
#define MSCCLPP_PORT_CHANNEL_HPP_

#include <set>

#include "core.hpp"
#include "port_channel_device.hpp"
#include "proxy.hpp"
@@ -45,6 +47,13 @@ class ProxyService : public BaseProxyService {
/// @return The ID of the memory region.
MemoryId addMemory(RegisteredMemory memory);

/// Unregister a memory region from the proxy service.
/// @note It is the caller’s responsibility to manage memory lifetimes safely.
/// ProxyService only ensures that memory remains valid while it is in use by the service;
/// other peers may still hold references to that memory beyond this scope.
/// @param memoryId The ID of the memory region to unregister.
void removeMemory(MemoryId memoryId);

/// Get a semaphore by ID.
/// @param id The ID of the semaphore.
/// @return The semaphore.
@@ -72,8 +81,10 @@
std::vector<std::shared_ptr<Host2DeviceSemaphore>> semaphores_;
std::vector<RegisteredMemory> memories_;
std::shared_ptr<Proxy> proxy_;
int deviceNumaNode;
std::unordered_map<std::shared_ptr<Connection>, int> inflightRequests;
std::set<MemoryId> reusableMemoryIds_;
int deviceNumaNode_;
std::unordered_map<std::shared_ptr<Connection>, int> inflightRequests_;
std::atomic_flag lock_;

void bindThread();

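A minimal caller-side sketch of the new register/unregister flow declared above. The demo function and the default-constructed RegisteredMemory are illustrative assumptions (a real caller would register an actual device buffer), not part of this diff:

#include <memory>

#include <mscclpp/core.hpp>
#include <mscclpp/port_channel.hpp>

// Hypothetical demo: register a memory region with the proxy service, use its ID
// in triggers, then hand the slot back so a later addMemory() can reuse it.
void demoAddRemoveMemory() {
  auto service = std::make_shared<mscclpp::ProxyService>();

  mscclpp::RegisteredMemory memory;  // placeholder for a real registered buffer
  mscclpp::MemoryId id = service->addMemory(memory);

  // ... issue port-channel triggers that reference `id` ...

  // Per the note on removeMemory, the caller must ensure that no local or remote
  // operation can still touch the buffer before unregistering it.
  service->removeMemory(id);

  // The freed slot is recycled: this call may hand back the same MemoryId.
  mscclpp::MemoryId reusedId = service->addMemory(memory);
  (void)reusedId;
}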
10 changes: 10 additions & 0 deletions include/mscclpp/utils.hpp
@@ -35,6 +35,16 @@ std::string getIBDeviceName(Transport ibTransport);
/// @return The InfiniBand transport associated with the specified device name.
Transport getIBTransportByDeviceName(const std::string& ibDeviceName);

/// A simple spinlock implementation using std::atomic_flag.
/// It is used to protect shared resources in a multi-threaded environment.
class SpinLock {
public:
SpinLock(std::atomic_flag& flag, bool yield = true);
~SpinLock();

private:
std::atomic_flag& flag_;
};
} // namespace mscclpp

#endif // MSCCLPP_UTILS_HPP_
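An illustrative sketch of how the RAII SpinLock declared here is used: the constructor spins until it acquires the flag, optionally yielding between attempts, and the destructor releases it. The Counter class is a hypothetical example, not part of this diff:

#include <atomic>

#include <mscclpp/utils.hpp>

// Hypothetical type that guards a counter with the new SpinLock.
class Counter {
 public:
  void increment() {
    mscclpp::SpinLock guard(lock_);  // default: yields while waiting
    ++value_;
  }  // flag is cleared when `guard` goes out of scope

  int read() {
    // Busy-wait without yielding, as ProxyService::handleTrigger() does on its hot path.
    mscclpp::SpinLock guard(lock_, /*yield=*/false);
    return value_;
  }

 private:
  std::atomic_flag lock_ = ATOMIC_FLAG_INIT;
  int value_ = 0;
};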
5 changes: 4 additions & 1 deletion src/ib.cc
@@ -59,8 +59,11 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
MSCCLPP_CUTHROW(cuCtxGetDevice(&dev));
MSCCLPP_CUTHROW(cuDeviceGetAttribute(&dmaBufSupported, CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED, dev));
#endif // !defined(__HIP_PLATFORM_AMD__)
if (cuMemAlloc && dmaBufSupported) {
if (cuMemAlloc) {
#if !defined(__HIP_PLATFORM_AMD__)
if (!dmaBufSupported) {
throw mscclpp::Error("Please make sure dma buffer is supported by the device", ErrorCode::InvalidUsage);
}
int fd;
MSCCLPP_CUTHROW(cuMemGetHandleForAddressRange(&fd, addr, pages * pageSize, CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0));

3 changes: 2 additions & 1 deletion src/nvls.cc
@@ -109,7 +109,8 @@ NvlsConnection::Impl::Impl(const std::vector<char>& data) {
}

NvlsConnection::Impl::~Impl() {
// we don't need to free multicast handle object according to NCCL.
// Please ensure that all memory mappings are unmapped from the handle before calling the connection destructor.
cuMemRelease(mcHandle_);
if (rootPid_ == getpid()) {
close(mcFileDesc_);
}
40 changes: 31 additions & 9 deletions src/port_channel.cc
@@ -20,28 +20,49 @@ MSCCLPP_API_CPP PortChannel::PortChannel(SemaphoreId semaphoreId, std::shared_pt

MSCCLPP_API_CPP ProxyService::ProxyService(size_t fifoSize)
: proxy_(std::make_shared<Proxy>([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); },
[&]() { bindThread(); }, fifoSize)) {
[&]() { bindThread(); }, fifoSize)),
lock_(false) {
int cudaDevice;
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
deviceNumaNode = getDeviceNumaNode(cudaDevice);
deviceNumaNode_ = getDeviceNumaNode(cudaDevice);
}

MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator,
std::shared_ptr<Connection> connection) {
SpinLock spin(lock_);
semaphores_.push_back(std::make_shared<Host2DeviceSemaphore>(communicator, connection));
return semaphores_.size() - 1;
}

MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr<Host2DeviceSemaphore> semaphore) {
SpinLock spin(lock_);
semaphores_.push_back(semaphore);
return semaphores_.size() - 1;
}

MSCCLPP_API_CPP MemoryId ProxyService::addMemory(RegisteredMemory memory) {
SpinLock spin(lock_);
if (!reusableMemoryIds_.empty()) {
auto it = reusableMemoryIds_.begin();
MemoryId memoryId = *it;
reusableMemoryIds_.erase(it);
memories_[memoryId] = memory;
return memoryId;
}
memories_.push_back(memory);
return memories_.size() - 1;
}

MSCCLPP_API_CPP void ProxyService::removeMemory(MemoryId memoryId) {
SpinLock spin(lock_);
if (reusableMemoryIds_.find(memoryId) != reusableMemoryIds_.end() || memoryId >= memories_.size()) {
WARN("Attempted to remove a memory that is not registered or already removed: %u", memoryId);
return;
}
memories_[memoryId] = RegisteredMemory();
Review comment from chhwang (Contributor), Jun 5, 2025:
We need a resource lock so that the removal doesn't happen while there are unflushed triggers in flight in the proxy. We also need a mechanism that prevents proxies on remote ranks from processing requests on this RegisteredMemory.

Reply from the author (Contributor):
I think users are responsible for safely releasing the buffer. Before removing it, they should ensure that no peers or device-side operations are still accessing this memory. A spinlock is added here to prevent race conditions inside the proxy channel.

reusableMemoryIds_.insert(memoryId);
}

MSCCLPP_API_CPP std::shared_ptr<Host2DeviceSemaphore> ProxyService::semaphore(SemaphoreId id) const {
return semaphores_[id];
}
@@ -59,13 +80,14 @@ MSCCLPP_API_CPP void ProxyService::startProxy() { proxy_->start(); }
MSCCLPP_API_CPP void ProxyService::stopProxy() { proxy_->stop(); }

MSCCLPP_API_CPP void ProxyService::bindThread() {
if (deviceNumaNode >= 0) {
numaBind(deviceNumaNode);
INFO(MSCCLPP_INIT, "NUMA node of ProxyService proxy thread is set to %d", deviceNumaNode);
if (deviceNumaNode_ >= 0) {
numaBind(deviceNumaNode_);
INFO(MSCCLPP_INIT, "NUMA node of ProxyService proxy thread is set to %d", deviceNumaNode_);
}
}

ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
SpinLock spin(lock_, false);
ChannelTrigger* trigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw);
std::shared_ptr<Host2DeviceSemaphore> semaphore = semaphores_[trigger->fields.semaphoreId];

@@ -77,19 +99,19 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
RegisteredMemory& src = memories_[trigger->fields.srcMemoryId];
semaphore->connection()->write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset,
trigger->fields.size);
inflightRequests[semaphore->connection()]++;
inflightRequests_[semaphore->connection()]++;
}

if (trigger->fields.type & TriggerFlag) {
semaphore->signal();
inflightRequests[semaphore->connection()]++;
inflightRequests_[semaphore->connection()]++;
}

if (trigger->fields.type & TriggerSync ||
(maxWriteQueueSize != -1 && inflightRequests[semaphore->connection()] > maxWriteQueueSize)) {
(maxWriteQueueSize != -1 && inflightRequests_[semaphore->connection()] > maxWriteQueueSize)) {
semaphore->connection()->flush();
result = ProxyHandlerResult::FlushFifoTailAndContinue;
inflightRequests[semaphore->connection()] = 0;
inflightRequests_[semaphore->connection()] = 0;
}

return result;
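Following the review discussion above, a hedged sketch of the caller-side release protocol: quiesce local and remote activity on the region before recycling its ID. Whether stopProxy() alone drains previously issued triggers, and whether a bootstrap barrier is the right peer synchronization, are assumptions for illustration only, not guarantees made by this change:

#include <memory>

#include <mscclpp/core.hpp>
#include <mscclpp/port_channel.hpp>

// Hypothetical helper: remove a registered memory only after local and remote
// activity on it has quiesced.
void safeRemoveMemory(mscclpp::ProxyService& service,
                      std::shared_ptr<mscclpp::Bootstrap> bootstrap,
                      mscclpp::MemoryId id) {
  // 1. The application stops issuing device-side triggers that reference `id`.
  // 2. Stop the proxy thread so no further triggers are processed locally
  //    (assumes earlier triggers have already been flushed).
  service.stopProxy();
  // 3. Agree with remote ranks that none of them will read or write this region
  //    anymore (assumes a bootstrap barrier is an acceptable sync point).
  bootstrap->barrier();
  // 4. Recycle the slot; a later addMemory() may return the same MemoryId.
  service.removeMemory(id);
  service.startProxy();
}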
13 changes: 9 additions & 4 deletions src/registered_memory.cc
@@ -277,7 +277,15 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {

RegisteredMemory::Impl::~Impl() {
// Close the CUDA IPC handle if it was opened during deserialization
if (data && transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash && getPidHash() != this->pidHash) {
if (data && transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash) {
if (getPidHash() == this->pidHash) {
// For local registered memory
if (fileDesc >= 0) {
close(fileDesc);
fileDesc = -1;
}
return;
}
void* base = static_cast<char*>(data) - getTransportInfo(Transport::CudaIpc).cudaIpcOffsetFromBase;
if (this->isCuMemMapAlloc) {
CUmemGenericAllocationHandle handle;
@@ -288,9 +296,6 @@ RegisteredMemory::Impl::~Impl() {
MSCCLPP_CULOG_WARN(cuMemUnmap((CUdeviceptr)base, size));
MSCCLPP_CULOG_WARN(cuMemRelease(handle));
MSCCLPP_CULOG_WARN(cuMemAddressFree((CUdeviceptr)base, size));
if (getNvlsMemHandleType() == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR && fileDesc >= 0) {
close(fileDesc);
}
} else {
cudaError_t err = cudaIpcCloseMemHandle(base);
if (err != cudaSuccess) {
10 changes: 10 additions & 0 deletions src/utils.cc
@@ -21,4 +21,14 @@ std::string getHostName(int maxlen, const char delim) {
return hostname.substr(0, i);
}

SpinLock::SpinLock(std::atomic_flag& flag, bool yield) : flag_(flag) {
while (flag_.test_and_set(std::memory_order_acq_rel)) {
if (yield) {
std::this_thread::yield();
}
}
}

SpinLock::~SpinLock() { flag_.clear(std::memory_order_release); }

} // namespace mscclpp
12 changes: 12 additions & 0 deletions test/unit/core_tests.cc
@@ -5,17 +5,20 @@
#include <gtest/gtest.h>

#include <mscclpp/core.hpp>
#include <mscclpp/port_channel.hpp>

class LocalCommunicatorTest : public ::testing::Test {
protected:
void SetUp() override {
bootstrap = std::make_shared<mscclpp::TcpBootstrap>(0, 1);
bootstrap->initialize(bootstrap->createUniqueId());
comm = std::make_shared<mscclpp::Communicator>(bootstrap);
proxyService = std::make_shared<mscclpp::ProxyService>();
}

std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
std::shared_ptr<mscclpp::Communicator> comm;
std::shared_ptr<mscclpp::ProxyService> proxyService;
};

TEST_F(LocalCommunicatorTest, RegisterMemory) {
@@ -36,3 +39,12 @@ TEST_F(LocalCommunicatorTest, SendMemoryToSelf) {
EXPECT_EQ(sameMemory.size(), memory.size());
EXPECT_EQ(sameMemory.transports(), memory.transports());
}

TEST_F(LocalCommunicatorTest, ProxyServiceAddRemoveMemory) {
auto memory = mscclpp::RegisteredMemory();
auto memoryId = proxyService->addMemory(memory);
EXPECT_EQ(memoryId, 0);
proxyService->removeMemory(memoryId);
memoryId = proxyService->addMemory(memory);
EXPECT_EQ(memoryId, 0);
}
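
A hedged sketch of an additional test one might add (not part of this PR), exercising the double-remove guard in removeMemory, which logs a warning and returns without modifying state:

TEST_F(LocalCommunicatorTest, ProxyServiceDoubleRemoveMemory) {
  auto memory = mscclpp::RegisteredMemory();
  auto memoryId = proxyService->addMemory(memory);
  proxyService->removeMemory(memoryId);
  proxyService->removeMemory(memoryId);  // warns, but does not throw or corrupt state
  EXPECT_EQ(proxyService->addMemory(memory), memoryId);
}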