Draft

62 commits
443bfa8  more  (fzyzcjy, Jun 17, 2025)
b986cce  more  (fzyzcjy, Jun 17, 2025)
3ea6f58  more  (fzyzcjy, Jun 17, 2025)
5d3513b  more  (fzyzcjy, Jun 17, 2025)
bda5695  more  (fzyzcjy, Jun 17, 2025)
3740762  more  (fzyzcjy, Jun 17, 2025)
ad4aee8  more  (fzyzcjy, Jun 17, 2025)
b5e4aad  more  (fzyzcjy, Jun 17, 2025)
240d058  more  (fzyzcjy, Jun 17, 2025)
5379d59  more  (fzyzcjy, Jun 17, 2025)
4fc8e79  more  (fzyzcjy, Jun 17, 2025)
2e90afe  more  (fzyzcjy, Jun 17, 2025)
3639a57  more  (fzyzcjy, Jun 17, 2025)
4ef8f05  more  (fzyzcjy, Jun 17, 2025)
047656e  more  (fzyzcjy, Jun 17, 2025)
c21f36d  more  (fzyzcjy, Jun 17, 2025)
7f3e4c0  more  (fzyzcjy, Jun 17, 2025)
92fb573  more  (fzyzcjy, Jun 17, 2025)
29f86f3  more  (fzyzcjy, Jun 17, 2025)
5557e70  more  (fzyzcjy, Jun 17, 2025)
9fd34e7  more  (fzyzcjy, Jun 17, 2025)
6417393  more  (fzyzcjy, Jun 17, 2025)
faaeaad  more  (fzyzcjy, Jun 17, 2025)
c38dbed  more  (fzyzcjy, Jun 17, 2025)
dc74c0a  more  (fzyzcjy, Jun 17, 2025)
61dea30  more  (fzyzcjy, Jun 17, 2025)
7d4bc93  more  (fzyzcjy, Jun 17, 2025)
5b78f22  more  (fzyzcjy, Jun 17, 2025)
75351cd  more  (fzyzcjy, Jun 17, 2025)
7bb12d4  more  (fzyzcjy, Jun 17, 2025)
0e5a155  more  (fzyzcjy, Jun 17, 2025)
87b3980  more  (fzyzcjy, Jun 17, 2025)
4398b5c  more  (fzyzcjy, Jun 17, 2025)
d7e9ce3  more  (fzyzcjy, Jun 17, 2025)
5b83cb8  more  (fzyzcjy, Jun 17, 2025)
f024df5  more  (fzyzcjy, Jun 17, 2025)
5a7b2f2  more  (fzyzcjy, Jun 17, 2025)
6052379  more  (fzyzcjy, Jun 17, 2025)
befcd27  more  (fzyzcjy, Jun 17, 2025)
df598ea  more  (fzyzcjy, Jun 17, 2025)
5b23a8a  more  (fzyzcjy, Jun 17, 2025)
210e499  more  (fzyzcjy, Jun 17, 2025)
379ac24  more  (fzyzcjy, Jun 17, 2025)
43999dc  more  (fzyzcjy, Jun 17, 2025)
7916011  more  (fzyzcjy, Jun 17, 2025)
0525f8f  more  (fzyzcjy, Jun 17, 2025)
c1d3606  Merge branch 'main-upstream_public' into feat/cu_mem_api  (fzyzcjy, Sep 1, 2025)
2bf764c  support NVFP4 data format in low latency dispatch  (shifangx, Jul 25, 2025)
d320aaa  add support fp32_vec_to_e2m1 for __CUDA_ARCH__ less than 1000  (shifangx, Aug 29, 2025)
d88e77e  change threshold for diff  (shifangx, Aug 29, 2025)
3a28b71  add debug message  (shifangx, Aug 29, 2025)
2add019  change physical layout to be (l, m/128, k/4, 32, 4, 4)  (shifangx, Aug 28, 2025)
9d9e395  use global scale for entire dispatch instead of per token scale  (shifangx, Sep 1, 2025)
ccf4eaf  change test case  (shifangx, Sep 3, 2025)
82147f2  change some names and dtype:  (shifangx, Sep 5, 2025)
fc15ca6  support padding m  (shifangx, Sep 6, 2025)
d89a25b  calibrate nvfp4 scale layout with grouped gemm  (shifangx, Sep 9, 2025)
87c9f8f  Fix wrong accuracy  (fzyzcjy, Sep 12, 2025)
769f8e6  Merge branch 'feat/cu_mem_api' into feat/dev_20250914  (fzyzcjy, Sep 14, 2025)
92a14cc  copy kernel and modify upper level  (fzyzcjy, Sep 14, 2025)
e00553c  more cherry pick  (fzyzcjy, Sep 14, 2025)
1fd57b0  rm maxnreg  (fzyzcjy, Sep 16, 2025)
188 changes: 175 additions & 13 deletions csrc/deep_ep.cpp
@@ -10,6 +10,131 @@
#include "kernels/api.cuh"
#include "kernels/configs.cuh"

namespace shared_memory {
void cu_mem_set_access_all(void* ptr, size_t size) {
    int device_count;
    CUDA_CHECK(cudaGetDeviceCount(&device_count));

    CUmemAccessDesc access_desc[device_count];
    for (int idx = 0; idx < device_count; ++idx) {
        access_desc[idx].location.type = CU_MEM_LOCATION_TYPE_DEVICE;
        access_desc[idx].location.id = idx;
        access_desc[idx].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    }

    CU_CHECK(cuMemSetAccess((CUdeviceptr)ptr, size, access_desc, device_count));
}

void cu_mem_free(void* ptr) {
    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));

    size_t size = 0;
    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

    CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size));
    CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size));
    CU_CHECK(cuMemRelease(handle));
}

size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) {
    size_t size = (size_raw + granularity - 1) & ~(granularity - 1);
    if (size == 0) size = granularity;
    return size;
}
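// Illustrative note, not part of the original change: the rounding above assumes
// `granularity` is a power of two, which holds for the values returned by
// cuMemGetAllocationGranularity in practice (typically 2 MiB). For example,
// size_raw = 5'000'000 with a 2'097'152-byte granularity rounds up to
// 3 * 2'097'152 = 6'291'456 bytes.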

bool support_fabric() {
    int device_count;
    CUDA_CHECK(cudaGetDeviceCount(&device_count));

    for (int device = 0; device < device_count; ++device) {
        int support = 0;
        CU_CHECK(cuDeviceGetAttribute(&support, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, device));
        if (!support) {
            return false;
        }
    }

    return true;
}

SharedMemoryAllocator::SharedMemoryAllocator() : enable_fabric(support_fabric()) {}

void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) {
    if (enable_fabric) {
        CUdevice device;
        CU_CHECK(cuCtxGetDevice(&device));

        CUmemAllocationProp prop = {};
        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
        prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
        prop.location.id = device;

        size_t granularity = 0;
        CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));

        size_t size = get_size_align_to_granularity(size_raw, granularity);

        CUmemGenericAllocationHandle handle;
        CU_CHECK(cuMemCreate(&handle, size, &prop, 0));

        CU_CHECK(cuMemAddressReserve((CUdeviceptr *)ptr, size, granularity, 0, 0));
        CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
        cu_mem_set_access_all(*ptr, size);
    } else {
        CUDA_CHECK(cudaMalloc(ptr, size_raw));
    }
}

void SharedMemoryAllocator::free(void* ptr) {
    if (enable_fabric) {
        cu_mem_free(ptr);
    } else {
        CUDA_CHECK(cudaFree(ptr));
    }
}

void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) {
    size_t size = 0;
    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

    mem_handle->size = size;

    if (enable_fabric) {
        CUmemGenericAllocationHandle handle;
        CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));

        CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
    } else {
        CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr));
    }
}

void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) {
    if (enable_fabric) {
        size_t size = mem_handle->size;

        CUmemGenericAllocationHandle handle;
        CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC));

        CU_CHECK(cuMemAddressReserve((CUdeviceptr *)ptr, size, 0, 0, 0));
        CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
        cu_mem_set_access_all(*ptr, size);
    } else {
        CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess));
    }
}

void SharedMemoryAllocator::close_mem_handle(void* ptr) {
    if (enable_fabric) {
        cu_mem_free(ptr);
    } else {
        CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
    }
}
}
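A minimal usage sketch of the allocator above (illustrative only, not taken from this PR): one rank allocates a buffer and exports its handle, and a peer maps it after receiving the handle bytes. The 256 MiB size and the comment-level byte transport are assumptions; in the PR itself the handles travel through Buffer::sync as all-gathered strings of HANDLE_SIZE bytes.

    shared_memory::SharedMemoryAllocator allocator;

    // Exporting rank: allocate device memory and publish a handle for it.
    void* local_ptr = nullptr;
    allocator.malloc(&local_ptr, 256 << 20);                  // assumed 256 MiB buffer
    shared_memory::MemHandle local_handle;
    allocator.get_mem_handle(&local_handle, local_ptr);
    // ... send the shared_memory::HANDLE_SIZE raw bytes of local_handle to the peer ...

    // Importing rank: rebuild the handle from the received bytes and map the buffer.
    shared_memory::MemHandle peer_handle;                     // filled via std::memcpy from the bytes
    void* peer_ptr = nullptr;
    allocator.open_mem_handle(&peer_ptr, &peer_handle);
    // ... read/write peer_ptr ...
    allocator.close_mem_handle(peer_ptr);

    // Exporting rank, at teardown:
    allocator.free(local_ptr);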

namespace deep_ep {

Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy):
@@ -46,8 +171,8 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_

if (num_nvl_bytes > 0) {
// Local IPC: alloc local memory and set local IPC handles
CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes));
CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
shared_memory_allocator.malloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes);
shared_memory_allocator.get_mem_handle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]);
buffer_ptrs_gpu = reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes);

// Set barrier signals
@@ -115,7 +240,8 @@ int Buffer::get_local_device_id() const {
}

pybind11::bytearray Buffer::get_local_ipc_handle() const {
return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE};
const shared_memory::MemHandle& handle = ipc_handles[nvl_rank];
return {reinterpret_cast<const char*>(&handle), sizeof(handle)};
}

pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
@@ -154,11 +280,11 @@ void Buffer::destroy() {
// Close remote IPC
if (is_available()) {
for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank)
CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
shared_memory_allocator.close_mem_handle(buffer_ptrs[i]);
}

// Free local buffer and error flag
CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
shared_memory_allocator.free(buffer_ptrs[nvl_rank]);
}

// Free NVSHMEM
@@ -194,13 +320,13 @@ void Buffer::sync(const std::vector<int> &device_ids,
for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++ i) {
EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());
auto handle_str = std::string(all_gathered_handles[offset + i].value());
EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE);
EP_HOST_ASSERT(handle_str.size() == shared_memory::HANDLE_SIZE);
if (offset + i != rank) {
std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
std::memcpy(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE);
shared_memory_allocator.open_mem_handle(&buffer_ptrs[i], &ipc_handles[i]);
barrier_signal_ptrs[i] = reinterpret_cast<int*>(static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
} else {
EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0);
EP_HOST_ASSERT(std::memcmp(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE) == 0);
}
}

@@ -1091,8 +1217,10 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
const std::optional<torch::Tensor>& x_global_scale,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool use_nvfp4, bool use_ue8m0_for_sf,
bool async, bool return_recv_hook) {
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT(low_latency_mode);
@@ -1137,8 +1265,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait(launch_stream, compute_stream);

// Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16));
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, use_nvfp4 ? hidden / 2 : hidden},
x.options().dtype(use_nvfp4 ? torch::kUInt8 : (use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16)));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
@@ -1148,6 +1276,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
void* packed_recv_x_scales_ptr = nullptr;
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");

EP_HOST_ASSERT(not (use_fp8 and use_nvfp4));
if (use_fp8) {
// TODO: support unaligned cases
EP_HOST_ASSERT(hidden % 512 == 0);
@@ -1161,6 +1290,35 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
}
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
}else if (use_nvfp4) {
constexpr int kNumPerChannels = 16;
constexpr int NUM_SF_ELEMS_PER_PACK = 4;
constexpr int mTileSize_dim_0 = 32;
constexpr int mTileSize_dim_1 = 4;
constexpr int mTileSize = mTileSize_dim_0 * mTileSize_dim_1;

assert(hidden % kNumPerChannels == 0);
auto l = num_local_experts;
auto m = num_ranks * num_max_dispatch_tokens_per_rank;
auto rm = (m + 127) / 128;
auto rk = (hidden + (kNumPerChannels * NUM_SF_ELEMS_PER_PACK) -1 ) / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);
// The physical layout is (l, rm, rk, 32, 4, 4).
if (use_ue8m0_for_sf) {
packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
torch::dtype(torch::kInt).device(torch::kCUDA));
} else {
packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA));
}
// After permute, the logical shape is (32, 4, rm, 4, rk, l)
packed_recv_x_scales = packed_recv_x_scales.value().permute({3, 4, 1, 5, 2, 0});

// The physical layout is (l, m, k // 2).
// After permute, the logical shape is (m, k // 2, l).
packed_recv_x = packed_recv_x.permute({1, 2, 0});

packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
EP_HOST_ASSERT(packed_recv_x_scales_ptr != nullptr);
}

// Kernel launch
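For reference, a standalone sketch of the scale-factor shape arithmetic in the NVFP4 branch above. The expert count, token count, and hidden size are assumed example values, not numbers taken from this PR; only the formulas mirror the dispatch code.

    #include <cstdio>

    int main() {
        constexpr int kNumPerChannels = 16;        // FP4 values covered by one scale factor
        constexpr int NUM_SF_ELEMS_PER_PACK = 4;   // scale factors packed into one 32-bit word
        const int l = 8;                           // assumed: num_local_experts
        const int m = 1024;                        // assumed: num_ranks * num_max_dispatch_tokens_per_rank
        const int hidden = 7168;                   // assumed hidden size

        const int rm = (m + 127) / 128;                                        // ceil(1024 / 128) = 8
        const int rk = (hidden + kNumPerChannels * NUM_SF_ELEMS_PER_PACK - 1)
                     / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);              // ceil(7168 / 64) = 112

        std::printf("scales, physical layout: (%d, %d, %d, 32, 4, 4)\n", l, rm, rk);
        std::printf("scales, logical shape after permute: (32, 4, %d, 4, %d, %d)\n", rm, rk, l);
        std::printf("packed FP4 payload per expert: (%d, %d) bytes\n", m, hidden / 2);  // two FP4 values per byte
        return 0;
    }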
@@ -1171,13 +1329,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
packed_recv_count.data_ptr<int>(),
cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
x_global_scale.has_value() ? x_global_scale->data_ptr<float>() : nullptr,
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
use_fp8, round_scale, use_ue8m0,
use_nvfp4, use_ue8m0_for_sf,
workspace, num_device_sms,
launch_stream, phases);
};
@@ -1212,7 +1372,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) {
const std::optional<torch::Tensor>& out,
bool overlap, const std::optional<torch::Tensor>& src_signals, uint32_t src_signal_expect_value) {
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT(low_latency_mode);

@@ -1282,7 +1443,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
num_topk, num_experts, rank, num_ranks,
use_logfmt,
workspace, num_device_sms,
launch_stream, phases, zero_copy);
launch_stream, phases, zero_copy,
overlap, src_signals.has_value() ? src_signals->data_ptr<uint32_t>() : nullptr, src_signal_expect_value);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));

36 changes: 34 additions & 2 deletions csrc/deep_ep.hpp
@@ -20,6 +20,33 @@
#define TORCH_EXTENSION_NAME deep_ep_cpp
#endif

namespace shared_memory {

union MemHandleInner {
    cudaIpcMemHandle_t cuda_ipc_mem_handle;
    CUmemFabricHandle cu_mem_fabric_handle;
};

struct MemHandle {
    MemHandleInner inner;
    size_t size;
};

constexpr size_t HANDLE_SIZE = sizeof(MemHandle);

class SharedMemoryAllocator {
public:
    SharedMemoryAllocator();
    void malloc(void** ptr, size_t size);
    void free(void* ptr);
    void get_mem_handle(MemHandle* mem_handle, void* ptr);
    void open_mem_handle(void** ptr, MemHandle* mem_handle);
    void close_mem_handle(void* ptr);
private:
    bool enable_fabric;
};
}
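MemHandle is exchanged between processes as an opaque blob of HANDLE_SIZE bytes (see get_local_ipc_handle and Buffer::sync in the .cpp diff), so a plain byte copy is enough to move it across ranks. A minimal sketch, where the helper names and the std::string transport are assumptions rather than anything defined by this PR:

    #include <cstring>
    #include <string>

    // Sender: serialize a handle previously filled by get_mem_handle().
    inline std::string serialize_handle(const shared_memory::MemHandle& handle) {
        return std::string(reinterpret_cast<const char*>(&handle), shared_memory::HANDLE_SIZE);
    }

    // Receiver: rebuild the handle from the received bytes, ready for open_mem_handle().
    inline shared_memory::MemHandle deserialize_handle(const std::string& wire) {
        shared_memory::MemHandle handle;
        std::memcpy(&handle, wire.data(), shared_memory::HANDLE_SIZE);
        return handle;
    }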

namespace deep_ep {

struct Buffer {
@@ -44,7 +71,7 @@ struct Buffer {
int num_device_sms;
int rank, rdma_rank, nvl_rank;
int num_ranks, num_rdma_ranks, num_nvl_ranks;
cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS];

// Stream for communication
at::cuda::CUDAStream comm_stream;
@@ -76,6 +103,8 @@ struct Buffer {
volatile int* moe_recv_rdma_counter = nullptr;
int* moe_recv_rdma_counter_mapped = nullptr;

shared_memory::SharedMemoryAllocator shared_memory_allocator;

public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy);

@@ -147,8 +176,10 @@ struct Buffer {
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
const std::optional<torch::Tensor>& x_global_scale,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool use_nvfp4, bool use_ue8m0_for_sf,
bool async, bool return_recv_hook);

std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Expand All @@ -157,7 +188,8 @@ struct Buffer {
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt);
const std::optional<torch::Tensor>& out = std::nullopt,
bool overlap = false, const std::optional<torch::Tensor>& src_signals = std::nullopt, uint32_t src_signal_expect_value = 0);

torch::Tensor
get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;
5 changes: 4 additions & 1 deletion csrc/kernels/api.cuh
@@ -144,12 +144,14 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
const float* x_global_scale,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool use_nvfp4, bool use_ue8m0_for_sf,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases);

@@ -163,7 +165,8 @@ void combine(void* combined_x,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_logfmt,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases, bool zero_copy);
cudaStream_t stream, int phases, bool zero_copy,
bool overlap, uint32_t* src_signals, uint32_t src_signal_expect_value);

} // namespace internode_ll
