132 changes: 86 additions & 46 deletions csrc/deep_ep.cpp
@@ -1524,17 +1524,18 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
#endif
}

std::tuple<torch::Tensor,
std::tuple<std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
torch::Tensor,
torch::Tensor,
torch::Tensor,
std::optional<EventHandle>,
std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x,
const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
bool use_expert_overlap, int num_rounds, int round_id, int send_num_sms, int recv_num_sms, bool hook_use_comm_stream,
int num_max_dispatch_tokens_per_rank,
int num_experts,
bool use_fp8,
@@ -1566,6 +1567,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks);
}

if (use_expert_overlap) {
EP_HOST_ASSERT(num_rounds >= 1 and (num_experts / num_ranks) % num_rounds == 0 and return_recv_hook);
if (send_num_sms == -1) send_num_sms = num_device_sms;
if (recv_num_sms == -1) recv_num_sms = num_device_sms;
}

auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_topk = static_cast<int>(topk_idx.size(1));
auto num_local_experts = num_experts / num_ranks;
@@ -1574,7 +1581,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^ 1];

if (not use_expert_overlap or round_id == (num_rounds - 1)) {
// Flip the double-buffer index once per logical dispatch: immediately in the classic path, only after the last round in overlap mode
low_latency_buffer_idx ^= 1;
}

// Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream
@@ -1585,42 +1597,55 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
stream_wait(launch_stream, compute_stream);

// Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
auto packed_recv_src_info =
torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));

// Allocate column-majored scales
auto packed_recv_x = std::optional<torch::Tensor>();
auto packed_recv_src_info = std::optional<torch::Tensor>();
auto packed_recv_layout_range = std::optional<torch::Tensor>();
auto packed_recv_count = std::optional<torch::Tensor>();
auto packed_recv_x_scales = std::optional<torch::Tensor>();
void* packed_recv_x_scales_ptr = nullptr;
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");

if (use_fp8) {
// TODO: support unaligned cases
EP_HOST_ASSERT(hidden % 512 == 0);
if (not use_ue8m0) {
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} else {
EP_HOST_ASSERT(round_scale);
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kInt).device(torch::kCUDA));

if (not use_expert_overlap or round_id == 0) {
packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
packed_recv_src_info =
torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));

// Allocate column-majored scales
void* packed_recv_x_scales_ptr = nullptr;
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");

if (use_fp8) {
// TODO: support unaligned cases
EP_HOST_ASSERT(hidden % 512 == 0);
if (not use_ue8m0) {
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} else {
EP_HOST_ASSERT(round_scale);
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kInt).device(torch::kCUDA));
}
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
}
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();

cached_packed_recv_x_ptr = packed_recv_x->data_ptr();
cached_packed_recv_x_scales_ptr = packed_recv_x_scales_ptr;
cached_packed_recv_count_ptr = packed_recv_count->data_ptr<int>();
cached_packed_recv_src_info_ptr = packed_recv_src_info->data_ptr<int>();
cached_packed_recv_layout_range_ptr = packed_recv_layout_range->data_ptr<int64_t>();
}

// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
auto launcher = [=](int phases, int num_sms, cudaStream_t stream) {
internode_ll::dispatch(
packed_recv_x.data_ptr(),
packed_recv_x_scales_ptr,
packed_recv_src_info.data_ptr<int>(),
packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(),
cached_packed_recv_x_ptr,
cached_packed_recv_x_scales_ptr,
cached_packed_recv_src_info_ptr,
cached_packed_recv_layout_range_ptr,
cached_packed_recv_count_ptr,
mask_buffer_ptr,
cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
@@ -1642,11 +1667,14 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
round_scale,
use_ue8m0,
workspace,
num_device_sms,
launch_stream,
phases);
num_sms,
stream,
phases,
use_expert_overlap,
num_rounds,
round_id);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE), use_expert_overlap ? send_num_sms : num_device_sms, launch_stream);

// Wait streams
std::optional<EventHandle> event;
@@ -1661,7 +1689,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x,
// Receiver callback
std::optional<std::function<void()>> recv_hook = std::nullopt;
if (return_recv_hook)
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE, use_expert_overlap ? recv_num_sms : num_device_sms, hook_use_comm_stream ? comm_stream : launch_stream); };

// Return values
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
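
Reviewer note: the deferred flip near the top of this function is the subtle part of the dispatch changes. Here is a standalone sketch of that pattern, assuming (as the code suggests) that all rounds of one overlapped dispatch must target the same buffer generation; `DoubleBuffer` is an illustrative name, not part of the library:

```cpp
// Minimal sketch of the deferred double-buffer flip, mirroring the logic above.
struct DoubleBuffer {
    int idx = 0;
    int current() const { return idx; }   // buffer generation used by this call
    int next() const { return idx ^ 1; }  // generation whose metadata gets pre-cleaned
    // The classic path flips on every call; expert-overlap mode flips once,
    // after the last round, so rounds 0..num_rounds-1 share one generation.
    void advance(bool use_expert_overlap, int round_id, int num_rounds) {
        if (!use_expert_overlap || round_id == num_rounds - 1)
            idx ^= 1;
    }
};
```
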
@@ -1684,6 +1712,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
bool zero_copy,
bool async,
bool return_recv_hook,
bool use_expert_overlap, int num_rounds, int send_round_id, int send_num_sms, int recv_num_sms, bool hook_use_comm_stream,
const std::optional<torch::Tensor>& out) {
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT(low_latency_mode);
@@ -1711,6 +1740,12 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
EP_HOST_ASSERT(combine_wait_recv_cost_stats->size(0) == num_ranks);
}

if (use_expert_overlap) {
EP_HOST_ASSERT(num_rounds >= 1 and (num_experts / num_ranks) % num_rounds == 0 and return_recv_hook);
if (send_num_sms == -1) send_num_sms = num_device_sms;
if (recv_num_sms == -1) recv_num_sms = num_device_sms;
}

auto hidden = static_cast<int>(x.size(2));
auto num_topk = static_cast<int>(topk_weights.size(1));
auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
@@ -1719,7 +1754,11 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^ 1];

if (not use_expert_overlap or send_round_id == (num_rounds - 1)) {
low_latency_buffer_idx ^= 1;
}

// Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream
@@ -1742,7 +1781,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio

// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
auto launcher = [=](int phases, int num_sms, cudaStream_t stream, int num_rounds, int send_round_id) {
internode_ll::combine(combined_x.data_ptr(),
buffer.combine_rdma_recv_data_buffer,
buffer.combine_rdma_recv_flag_buffer,
@@ -1765,12 +1804,13 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
num_ranks,
use_logfmt,
workspace,
num_device_sms,
launch_stream,
num_sms,
stream,
phases,
zero_copy);
zero_copy,
use_expert_overlap, num_rounds, send_round_id);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE), use_expert_overlap ? send_num_sms : num_device_sms, launch_stream, num_rounds, send_round_id);

// Wait streams
std::optional<EventHandle> event;
@@ -1785,7 +1825,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
// Receiver callback
std::optional<std::function<void()>> recv_hook = std::nullopt;
if (return_recv_hook)
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE, use_expert_overlap ? recv_num_sms : num_device_sms, hook_use_comm_stream ? comm_stream : launch_stream, 1, 0); };

// Return values
return {combined_x, event, recv_hook};
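
Reviewer note: taken together, the dispatch and combine changes imply a caller that launches every round's SEND phase eagerly and drains the returned hooks afterwards. A plausible shape for such a driver, sketched with hypothetical callback types since the actual harness is not part of this diff (the real entry points take the full argument lists shown above):

```cpp
#include <functional>
#include <vector>

// Hypothetical round driver: dispatch_round(r) launches round r's SEND phase
// and returns its RECV hook (overlap mode asserts return_recv_hook == true);
// compute_round(r) stands in for the expert GEMMs of round r.
void run_overlapped_rounds(int num_rounds,
                           const std::function<std::function<void()>(int)>& dispatch_round,
                           const std::function<void(int)>& compute_round) {
    std::vector<std::function<void()>> recv_hooks;
    recv_hooks.reserve(num_rounds);
    for (int r = 0; r < num_rounds; ++r)
        recv_hooks.push_back(dispatch_round(r));  // all SEND phases go out up front
    for (int r = 0; r < num_rounds; ++r) {
        recv_hooks[r]();   // RECV kernel (lands on comm_stream if hook_use_comm_stream)
        compute_round(r);  // overlaps with traffic still in flight for later rounds
    }
}
```

With send_num_sms and recv_num_sms independently tunable, the SEND and RECV kernels can be sized so the RECV phase leaves SMs free for the overlapped expert computation.
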
17 changes: 13 additions & 4 deletions csrc/deep_ep.hpp
@@ -98,6 +98,13 @@ struct Buffer {
// Workspace
void* workspace = nullptr;

// Cached receive-tensor pointers for expert-overlap mode (set in round 0, reused by later rounds)
void* cached_packed_recv_x_ptr = nullptr;
void* cached_packed_recv_x_scales_ptr = nullptr;
int* cached_packed_recv_count_ptr = nullptr;
int* cached_packed_recv_src_info_ptr = nullptr;
int64_t* cached_packed_recv_layout_range_ptr = nullptr;

// Host-side MoE info
volatile int* moe_recv_counter = nullptr;
int* moe_recv_counter_mapped = nullptr;
@@ -254,17 +261,18 @@ struct Buffer {

void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);

std::tuple<torch::Tensor,
std::tuple<std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
torch::Tensor,
torch::Tensor,
torch::Tensor,
std::optional<EventHandle>,
std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x,
const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
bool use_expert_overlap, int num_rounds, int round_id, int send_num_sms, int recv_num_sms, bool hook_use_comm_stream,
int num_max_dispatch_tokens_per_rank,
int num_experts,
bool use_fp8,
@@ -286,6 +294,7 @@ struct Buffer {
bool zero_copy,
bool async,
bool return_recv_hook,
bool use_expert_overlap, int num_rounds, int send_round_id, int send_num_sms, int recv_num_sms, bool hook_use_comm_stream,
const std::optional<torch::Tensor>& out = std::nullopt);

torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;
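
Reviewer note: the new cached_* members exist because, with overlap enabled, only round 0 allocates the packed receive tensors; later rounds return std::nullopt (hence the widened tuple of optionals) while the kernels keep writing through the raw pointers cached in round 0. A standalone sketch of that pattern, with illustrative names rather than the library's API:

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

struct OverlapCache {
    std::optional<std::vector<float>> recv_x;  // owning buffer, allocated in round 0
    float* recv_x_ptr = nullptr;               // raw pointer the kernel consumes
};

// Returns the freshly allocated buffer in round 0 and std::nullopt afterwards;
// the caller must keep the round-0 result alive for all remaining rounds.
std::optional<std::vector<float>*> dispatch_round(OverlapCache& cache,
                                                  int round_id, std::size_t n) {
    if (round_id == 0) {
        cache.recv_x.emplace(n);               // allocate once per overlapped dispatch
        cache.recv_x_ptr = cache.recv_x->data();
        return &*cache.recv_x;
    }
    assert(cache.recv_x_ptr != nullptr);       // round 0 must have run first
    return std::nullopt;
}
```
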
4 changes: 3 additions & 1 deletion csrc/event.hpp
@@ -31,7 +31,9 @@ torch::Event create_event(const at::cuda::CUDAStream& s) {
}

void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) {
EP_HOST_ASSERT(s_0.id() != s_1.id());
if (s_0.id() == s_1.id()) {
return;
}
s_0.unwrap().wait(create_event(s_1));
}

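
Reviewer note: the relaxed stream_wait is presumably needed because, with hook_use_comm_stream, the hook's RECV kernel can land on the same stream the surrounding code later waits on; a same-stream wait is already ordered by stream semantics, so it becomes a no-op instead of an assertion failure. The equivalent raw-CUDA pattern, for reference:

```cpp
#include <cuda_runtime.h>

// Make s0 wait for all work currently enqueued on s1. If both handles name the
// same stream, stream order already provides the dependency, so do nothing.
void stream_wait_sketch(cudaStream_t s0, cudaStream_t s1) {
    if (s0 == s1) return;
    cudaEvent_t ev;
    cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
    cudaEventRecord(ev, s1);         // capture the current tail of s1
    cudaStreamWaitEvent(s0, ev, 0);  // s0 stalls until that point completes
    cudaEventDestroy(ev);            // deferred destruction once the event retires
}
```
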
10 changes: 8 additions & 2 deletions csrc/kernels/api.cuh
@@ -310,7 +310,10 @@ void dispatch(void* packed_recv_x,
void* workspace,
int num_device_sms,
cudaStream_t stream,
int phases);
int phases,
bool use_expert_overlap,
int num_rounds,
int round_id);

void combine(void* combined_x,
void* rdma_recv_x,
@@ -337,7 +340,10 @@ void combine(void* combined_x,
int num_device_sms,
cudaStream_t stream,
int phases,
bool zero_copy);
bool zero_copy,
bool use_expert_overlap,
int num_rounds,
int round_id);

void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* output_mask_tensor, cudaStream_t stream);

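
Reviewer note: the host-side asserts guarantee (num_experts / num_ranks) % num_rounds == 0, so each round can own an equal slice of the local experts. The kernel-side indexing lives in internode_ll.cu, which is not part of this diff; under the assumption of a contiguous partition, the arithmetic would look like:

```cpp
// Per-round expert slice under a contiguous-partition assumption (sketch only;
// the real layout is decided inside internode_ll::dispatch / combine).
struct ExpertSlice {
    int begin;
    int end;  // half-open range of local expert indices for this round
};

ExpertSlice slice_for_round(int num_experts, int num_ranks,
                            int num_rounds, int round_id) {
    const int num_local_experts = num_experts / num_ranks;
    const int per_round = num_local_experts / num_rounds;  // exact, by the assert
    return {round_id * per_round, (round_id + 1) * per_round};
}
```
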