diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index 0789cd58..7fe9769b 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -10,6 +10,131 @@
 #include "kernels/api.cuh"
 #include "kernels/configs.cuh"
 
+namespace shared_memory {
+void cu_mem_set_access_all(void* ptr, size_t size) {
+    int device_count;
+    CUDA_CHECK(cudaGetDeviceCount(&device_count));
+
+    CUmemAccessDesc access_desc[device_count];
+    for (int idx = 0; idx < device_count; ++idx) {
+        access_desc[idx].location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        access_desc[idx].location.id = idx;
+        access_desc[idx].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+    }
+
+    CU_CHECK(cuMemSetAccess((CUdeviceptr)ptr, size, access_desc, device_count));
+}
+
+void cu_mem_free(void* ptr) {
+    CUmemGenericAllocationHandle handle;
+    CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));
+
+    size_t size = 0;
+    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));
+
+    CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size));
+    CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size));
+    CU_CHECK(cuMemRelease(handle));
+}
+
+size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) {
+    size_t size = (size_raw + granularity - 1) & ~(granularity - 1);
+    if (size == 0) size = granularity;
+    return size;
+}
+
+bool support_fabric() {
+    int device_count;
+    CUDA_CHECK(cudaGetDeviceCount(&device_count));
+
+    for (int device = 0; device < device_count; ++device) {
+        int support = 0;
+        CU_CHECK(cuDeviceGetAttribute(&support, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, device));
+        if (!support) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+SharedMemoryAllocator::SharedMemoryAllocator() : enable_fabric(support_fabric()) {}
+
+void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) {
+    if (enable_fabric) {
+        CUdevice device;
+        CU_CHECK(cuCtxGetDevice(&device));
+
+        CUmemAllocationProp prop = {};
+        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
+        prop.location.id = device;
+
+        size_t granularity = 0;
+        CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
+
+        size_t size = get_size_align_to_granularity(size_raw, granularity);
+
+        CUmemGenericAllocationHandle handle;
+        CU_CHECK(cuMemCreate(&handle, size, &prop, 0));
+
+        CU_CHECK(cuMemAddressReserve((CUdeviceptr *)ptr, size, granularity, 0, 0));
+        CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
+        cu_mem_set_access_all(*ptr, size);
+    } else {
+        CUDA_CHECK(cudaMalloc(ptr, size_raw));
+    }
+}
+
+void SharedMemoryAllocator::free(void* ptr) {
+    if (enable_fabric) {
+        cu_mem_free(ptr);
+    } else {
+        CUDA_CHECK(cudaFree(ptr));
+    }
+}
+
+void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) {
+    size_t size = 0;
+    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));
+
+    mem_handle->size = size;
+
+    if (enable_fabric) {
+        CUmemGenericAllocationHandle handle;
+        CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));
+
+        CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
+    } else {
+        CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr));
+    }
+}
+
+void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) {
+    if (enable_fabric) {
+        size_t size = mem_handle->size;
+
+        CUmemGenericAllocationHandle handle;
+        CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC));
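+        // Map the imported fabric allocation into a fresh local VA range and enable access from all visible devices.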
+
+        CU_CHECK(cuMemAddressReserve((CUdeviceptr *)ptr, size, 0, 0, 0));
+        CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
+        cu_mem_set_access_all(*ptr, size);
+    } else {
+        CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess));
+    }
+}
+
+void SharedMemoryAllocator::close_mem_handle(void* ptr) {
+    if (enable_fabric) {
+        cu_mem_free(ptr);
+    } else {
+        CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
+    }
+}
+}
+
 namespace deep_ep {
 
 Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy):
@@ -46,8 +171,8 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
     if (num_nvl_bytes > 0) {
         // Local IPC: alloc local memory and set local IPC handles
-        CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes));
-        CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
+        shared_memory_allocator.malloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes);
+        shared_memory_allocator.get_mem_handle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]);
         buffer_ptrs_gpu = reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes);
 
         // Set barrier signals
@@ -115,7 +240,8 @@ int Buffer::get_local_device_id() const {
 }
 
 pybind11::bytearray Buffer::get_local_ipc_handle() const {
-    return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE};
+    const shared_memory::MemHandle& handle = ipc_handles[nvl_rank];
+    return {reinterpret_cast<const char*>(&handle), sizeof(handle)};
 }
 
 pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
@@ -154,11 +280,11 @@ void Buffer::destroy() {
         // Close remote IPC
         if (is_available()) {
             for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank)
-                CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
+                shared_memory_allocator.close_mem_handle(buffer_ptrs[i]);
         }
 
         // Free local buffer and error flag
-        CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
+        shared_memory_allocator.free(buffer_ptrs[nvl_rank]);
     }
 
     // Free NVSHMEM
@@ -194,13 +320,13 @@ void Buffer::sync(const std::vector<int> &device_ids,
         for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++ i) {
             EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());
             auto handle_str = std::string(all_gathered_handles[offset + i].value());
-            EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE);
+            EP_HOST_ASSERT(handle_str.size() == shared_memory::HANDLE_SIZE);
             if (offset + i != rank) {
-                std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
-                CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
+                std::memcpy(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE);
+                shared_memory_allocator.open_mem_handle(&buffer_ptrs[i], &ipc_handles[i]);
                 barrier_signal_ptrs[i] = reinterpret_cast<int*>(static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
             } else {
-                EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0);
+                EP_HOST_ASSERT(std::memcmp(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE) == 0);
             }
         }
 
@@ -1091,8 +1217,10 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                              const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                              const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
+                             const std::optional<torch::Tensor>& x_global_scale,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
                              bool use_fp8, bool round_scale, bool use_ue8m0,
+                             bool use_nvfp4, bool use_ue8m0_for_sf,
                              bool async, bool return_recv_hook) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
@@ -1137,8 +1265,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     stream_wait(launch_stream, compute_stream);
 
     // Allocate packed tensors
-    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
-                                      x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16));
+    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, use_nvfp4 ? hidden / 2 : hidden},
+                                      x.options().dtype(use_nvfp4 ? torch::kUInt8 : (use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16)));
     auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
     auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
     auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
@@ -1148,6 +1276,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     void* packed_recv_x_scales_ptr = nullptr;
     EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
 
+    EP_HOST_ASSERT(not (use_fp8 and use_nvfp4));
     if (use_fp8) {
         // TODO: support unaligned cases
         EP_HOST_ASSERT(hidden % 512 == 0);
@@ -1161,6 +1290,35 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         }
         packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
         packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+    } else if (use_nvfp4) {
+        constexpr int kNumPerChannels = 16;
+        constexpr int NUM_SF_ELEMS_PER_PACK = 4;
+        constexpr int mTileSize_dim_0 = 32;
+        constexpr int mTileSize_dim_1 = 4;
+        constexpr int mTileSize = mTileSize_dim_0 * mTileSize_dim_1;
+
+        assert(hidden % kNumPerChannels == 0);
+        auto l = num_local_experts;
+        auto m = num_ranks * num_max_dispatch_tokens_per_rank;
+        auto rm = (m + 127) / 128;
+        auto rk = (hidden + (kNumPerChannels * NUM_SF_ELEMS_PER_PACK) - 1) / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);
+        // The physical layout is (l, rm, rk, 32, 4, 4).
+        if (use_ue8m0_for_sf) {
+            packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
+                                                torch::dtype(torch::kInt).device(torch::kCUDA));
+        } else {
+            packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
+                                                torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA));
+        }
+        // After permute, the logical shape is (32, 4, rm, 4, rk, l).
+        packed_recv_x_scales = packed_recv_x_scales.value().permute({3, 4, 1, 5, 2, 0});
+
+        // The physical layout is (l, m, k // 2).
+        // After permute, the logical shape is (m, k // 2, l).
+        packed_recv_x = packed_recv_x.permute({1, 2, 0});
+
+        packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+        EP_HOST_ASSERT(packed_recv_x_scales_ptr != nullptr);
     }
 
     // Kernel launch
@@ -1171,6 +1329,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
                                packed_recv_count.data_ptr<int>(),
                                cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
                                dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
+                               x_global_scale.has_value() ? x_global_scale->data_ptr<float>() : nullptr,
                                buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
                                buffer.dispatch_rdma_send_buffer,
                                x.data_ptr(), topk_idx.data_ptr<int64_t>(),
@@ -1178,6 +1337,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
                                num_tokens, hidden, num_max_dispatch_tokens_per_rank,
                                num_topk, num_experts, rank, num_ranks,
                                use_fp8, round_scale, use_ue8m0,
+                               use_nvfp4, use_ue8m0_for_sf,
                                workspace, num_device_sms,
                                launch_stream, phases);
     };
@@ -1212,7 +1372,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
                             const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
                             int num_max_dispatch_tokens_per_rank, int num_experts,
                             bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
-                            const std::optional<torch::Tensor>& out) {
+                            const std::optional<torch::Tensor>& out,
+                            bool overlap, const std::optional<torch::Tensor>& src_signals, uint32_t src_signal_expect_value) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
@@ -1282,7 +1443,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
                               num_topk, num_experts, rank, num_ranks,
                               use_logfmt,
                               workspace, num_device_sms,
-                              launch_stream, phases, zero_copy);
+                              launch_stream, phases, zero_copy,
+                              overlap, src_signals.has_value() ? src_signals->data_ptr<uint32_t>() : nullptr, src_signal_expect_value);
     };
 
     launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp
index aa62ccb0..302b62cd 100644
--- a/csrc/deep_ep.hpp
+++ b/csrc/deep_ep.hpp
@@ -20,6 +20,33 @@
 #define TORCH_EXTENSION_NAME deep_ep_cpp
 #endif
 
+namespace shared_memory {
+
+union MemHandleInner {
+    cudaIpcMemHandle_t cuda_ipc_mem_handle;
+    CUmemFabricHandle cu_mem_fabric_handle;
+};
+
+struct MemHandle {
+    MemHandleInner inner;
+    size_t size;
+};
+
+constexpr size_t HANDLE_SIZE = sizeof(MemHandle);
+
+class SharedMemoryAllocator {
+public:
+    SharedMemoryAllocator();
+    void malloc(void** ptr, size_t size);
+    void free(void* ptr);
+    void get_mem_handle(MemHandle* mem_handle, void* ptr);
+    void open_mem_handle(void** ptr, MemHandle* mem_handle);
+    void close_mem_handle(void* ptr);
+private:
+    bool enable_fabric;
+};
+}
+
 namespace deep_ep {
 
 struct Buffer {
@@ -44,7 +71,7 @@ struct Buffer {
     int num_device_sms;
     int rank, rdma_rank, nvl_rank;
     int num_ranks, num_rdma_ranks, num_nvl_ranks;
-    cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
+    shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS];
 
     // Stream for communication
     at::cuda::CUDAStream comm_stream;
@@ -76,6 +103,8 @@ struct Buffer {
     volatile int* moe_recv_rdma_counter = nullptr;
     int* moe_recv_rdma_counter_mapped = nullptr;
 
+    shared_memory::SharedMemoryAllocator shared_memory_allocator;
+
 public:
     Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy);
 
@@ -147,8 +176,10 @@ struct Buffer {
     low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                          const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                          const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
+                         const std::optional<torch::Tensor>& x_global_scale,
                          int num_max_dispatch_tokens_per_rank, int num_experts,
                          bool use_fp8, bool round_scale, bool use_ue8m0,
+                         bool use_nvfp4, bool use_ue8m0_for_sf,
                          bool async, bool return_recv_hook);
 
     std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
@@ -157,7 +188,8 @@ struct Buffer {
                         const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
                         int num_max_dispatch_tokens_per_rank, int num_experts,
                         bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
-                        const std::optional<torch::Tensor>& out = std::nullopt);
+                        const std::optional<torch::Tensor>& out = std::nullopt,
+                        bool overlap = false, const std::optional<torch::Tensor>& src_signals = std::nullopt, uint32_t src_signal_expect_value = 0);
 
     torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;
diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh
index d34775fd..a0eb0635 100644
--- a/csrc/kernels/api.cuh
+++ b/csrc/kernels/api.cuh
@@ -144,12 +144,14 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
               int64_t* dispatch_wait_recv_cost_stats,
+              const float* x_global_scale,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
+              bool use_nvfp4, bool use_ue8m0_for_sf,
               void* workspace, int num_device_sms,
               cudaStream_t stream, int phases);
@@ -163,7 +165,8 @@ void combine(void* combined_x,
              int num_topk, int num_experts, int rank, int num_ranks,
              bool use_logfmt,
              void* workspace, int num_device_sms,
-             cudaStream_t stream, int phases, bool zero_copy);
+             cudaStream_t stream, int phases, bool zero_copy,
+             bool overlap, uint32_t* src_signals, uint32_t src_signal_expect_value);
 
 } // namespace internode_ll
diff --git a/csrc/kernels/exception.cuh b/csrc/kernels/exception.cuh
index 7db0ddb7..3026374b 100644
--- a/csrc/kernels/exception.cuh
+++ b/csrc/kernels/exception.cuh
@@ -31,6 +31,18 @@ do { \
 } while (0)
 #endif
 
+#ifndef CU_CHECK
+#define CU_CHECK(cmd) \
+do { \
+    CUresult e = (cmd); \
+    if (e != CUDA_SUCCESS) { \
+        const char *error_str = NULL; \
+        cuGetErrorString(e, &error_str); \
+        throw EPException("CU", __FILE__, __LINE__, std::string(error_str)); \
+    } \
+} while (0)
+#endif
+
 #ifndef EP_HOST_ASSERT
 #define EP_HOST_ASSERT(cond) \
 do { \
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index 391a4b3d..23494e8f 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -36,13 +36,171 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
         clean_0, num_clean_int_0, clean_1, num_clean_int_1);
 }
 
-template <bool kUseFP8, bool kUseUE8M0, int kHidden>
+
+struct PackedVec {
+    __nv_bfloat162 elts[4];
+};
+
+using Type = __nv_bfloat16;
+
+__device__ __forceinline__ float exp2f_rcp(uint8_t exp) {
+    constexpr uint32_t FP32_EXPONENT_BIAS = 127;
+    return (exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(exp));
+}
+
+// Fast reciprocal.
+inline __device__ float reciprocal_approximate_ftz(float a) {
+    float b;
+    asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
+    return b;
+}
+
+// Convert 1 float value into one e2m1 value (represented as one uint8_t).
+__device__ inline uint8_t float_to_e2m1(float f) {
+    // Get the sign
+    uint8_t sign = (f < 0);
+    float abs_f = fabsf(f);
+    float abs_f_log2 = log2f(abs_f);
+    // Map the float to a 2-bit exponent
+    uint8_t exp, mant;
+    if (abs_f_log2 < 0) {
+        exp = 0;
+        if (abs_f_log2 < -1) {
+            mant = 0;
+        } else {
+            mant = 1;
+        }
+    } else {
+        exp = static_cast<uint8_t>(floorf(abs_f_log2 + 1));
+        exp = fminf(exp, 3.0f);
+        mant = (abs_f_log2 + 1 - exp > 0.5f) ? 1 : 0;
+    }
+    // Take one bit for the mantissa
+    return (sign << 3) | (exp << 1) | mant;
+}
+
+// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
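+// NOTES: on SM100a this lowers to the `cvt.rn.satfinite.e2m1x2.f32` PTX instruction; other targets fall back to the scalar `float_to_e2m1` conversion above.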
+inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
+    // The PTX instructions used here require sm100a.
+    #if CUDA_VERSION >= 12080
+    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
+    uint32_t val;
+    asm volatile(
+        "{\n"
+        ".reg .b8 byte0;\n"
+        ".reg .b8 byte1;\n"
+        ".reg .b8 byte2;\n"
+        ".reg .b8 byte3;\n"
+        "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
+        "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
+        "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
+        "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
+        "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
+        "}"
+        : "=r"(val)
+        : "f"(array[0].x),
+          "f"(array[0].y),
+          "f"(array[1].x),
+          "f"(array[1].y),
+          "f"(array[2].x),
+          "f"(array[2].y),
+          "f"(array[3].x),
+          "f"(array[3].y));
+    return val;
+    #else
+    uint32_t val = 0;
+    float2* data = reinterpret_cast<float2*>(&array[0]);
+    for (int i = 0; i < 4; ++i) {
+        val |= (float_to_e2m1(data[i].x) & 0xFF) << (8 * i);
+        val |= (float_to_e2m1(data[i].y) & 0xFF) << (8 * i + 4);
+    }
+    return val;
+    #endif
+    #endif
+}
+
+constexpr int CVT_ELTS_PER_THREAD = 8;
+// Quantizes the provided PackedVec into the uint32_t output
+template <int SF_VEC_SIZE, bool UE8M0_SF>
+__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) {
+    // Get the absolute maximum among the local 8 values.
+    auto localMax = __habs2(vec.elts[0]);
+
+    // Local maximum value.
+#pragma unroll
+    for (int i = 1; i < CVT_ELTS_PER_THREAD / 2; i++) {
+        localMax = __hmax2(localMax, __habs2(vec.elts[i]));
+    }
+
+    constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_ELTS_PER_THREAD;
+    EP_STATIC_ASSERT(CVT_NUM_THREADS_PER_SF == 2 or CVT_NUM_THREADS_PER_SF == 4, "Invalid number of threads per SF");
+    // Get the absolute maximum among all 16 values (two threads for 16, four threads for 32).
+    localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
+    if constexpr (CVT_NUM_THREADS_PER_SF == 4) {
+        localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax);
+    }
+    // Get the final absolute maximum value.
+    float vecMax = float(__hmax(localMax.x, localMax.y));
+
+    // 8-bit representation of the SF.
+    uint8_t fp8SFVal;
+    float outputScale;
+    // Write the SF to global memory (STG.8).
+    if constexpr (UE8M0_SF) {
+        __nv_fp8_e8m0 tmp;
+        // Scale the max value to the range of E2M1.
+        vecMax *= reciprocal_approximate_ftz(6.0f);
+        tmp.__x = __nv_cvt_float_to_e8m0(vecMax, __NV_SATFINITE, cudaRoundPosInf);
+        fp8SFVal = tmp.__x;
+        outputScale = exp2f_rcp(fp8SFVal);
+    } else {
+        // Get the SF (max value of the vector / max value of E2M1).
+        // Maximum value of E2M1 = 6.0.
+        // TODO: use half as the compute data type.
+        auto SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
+        // Here SFValue is always positive, so E4M3 is the same as UE4M3.
+        __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
+        fp8SFVal = tmp.__x;
+        SFValue = static_cast<float>(tmp);
+        // Get the output scale.
+        // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal)) * reciprocal(SFScaleVal))
+        outputScale = SFValue != 0
+                          ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal))
+                          : 0.0f;
+    }
+
+    if (SFout) {
+        // Write the SF to global memory (STG.8).
+        *SFout = fp8SFVal;
+    }
+
+    // Convert the input to float.
+    float2 fp2Vals[CVT_ELTS_PER_THREAD / 2];
+
+#pragma unroll
+    for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) {
+        fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
+        fp2Vals[i].x *= outputScale;
+        fp2Vals[i].y *= outputScale;
+    }
+
+    // Convert to e2m1 values.
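+    // (each thread packs its 8 scaled values into a single 32-bit word)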
+    uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
+
+    // Write the e2m1 values to global memory.
+    return e2m1Vec;
+}
+
+template
+__global__ __launch_bounds__(1024, 1) void
 dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
          int* packed_recv_count,
          int* cumulative_local_expert_recv_stats,
          int64_t* dispatch_wait_recv_cost_stats,
+         const float* x_global_scale,
          void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
          const void* x, const int64_t* topk_idx,
          int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
@@ -62,20 +220,28 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
 
     // May extract UE8M0 from the scales
-    using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
-    using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
+    using scale_t = std::conditional_t;
+    using packed_t = std::conditional_t;
     EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
+    EP_STATIC_ASSERT(!(kUseFP8 && kUseNVFP4), "FP8 and NVFP4 cannot be used together");
 
     // FP8 staffs
-    constexpr int kNumPerChannels = 128;
+    constexpr int kNumPerChannels = kUseNVFP4 ? 16 : 128;
     const int num_scales = kHidden / kNumPerChannels;
-    const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
+    constexpr size_t hidden_bytes =
+        kUseNVFP4
+            ? kHidden * sizeof(__nv_fp8_storage_t) / 2
+            : kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
     const size_t hidden_int4 = hidden_bytes / sizeof(int4);
 
     // Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales
     // NOTES: currently we have 3 reserved int fields for future use
-    using vec_t = std::conditional_t<kUseFP8, int2, int4>;
-    const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
+    using vec_t = std::conditional_t<
+        kUseNVFP4,
+        int32_t,
+        std::conditional_t<kUseFP8, int2, int4>>;
+    using rdma_x_scale_t = std::conditional_t<kUseNVFP4, uint8_t, float>;
+    const size_t num_bytes_per_msg = sizeof(int4) + ((kUseFP8 || kUseNVFP4) ? (hidden_bytes + num_scales * sizeof(rdma_x_scale_t)) : hidden_bytes);
     const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
     EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
 
@@ -101,13 +267,19 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         const auto x_int4 = static_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
         const auto rdma_x_src_idx = reinterpret_cast<int*>(static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
         const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
-        const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
+        const auto rdma_x_scales = reinterpret_cast<rdma_x_scale_t*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
 
         // Overlap top-k index read and source token index writes
         auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
         thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
+        float SFScaleVal = 1.0f;
+        if constexpr (kUseNVFP4) {
+            // Get the global scaling value
+            EP_DEVICE_ASSERT(x_global_scale != nullptr);
+            SFScaleVal = *(static_cast<const float*>(x_global_scale));
+        }
 
-        // FP8 cast
+        // FP8 or NVFP4 cast
         EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, "Must use the full warp to reduce");
         #pragma unroll
         for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
@@ -141,6 +313,20 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
                     fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
                 }
                 rdma_x_vec[i] = int2_value;
+            } else if constexpr (kUseNVFP4) {
+                // Convert to NVFP4
+                uint8_t sf_val;
+                PackedVec vec = *reinterpret_cast<PackedVec*>(&int4_value);
+                uint32_t result = cvt_warp_fp16_to_fp4(vec, SFScaleVal, &sf_val);
+
+                // Write the scale to the send buffer
+                if (lane_id % 2 == 0) {
+                    EP_DEVICE_ASSERT((i * kNumElemsPerRead) % kNumPerChannels == 0);
+                    int rdma_x_scale_idx = i * kNumElemsPerRead / kNumPerChannels;
+                    rdma_x_scales[rdma_x_scale_idx] = sf_val;
+                }
+                // Cast into the send buffer
+                rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&result);
             } else {
                 // Reinterpret-cast is for C++14 compatibility
                 rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
@@ -264,7 +450,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
     const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
     const auto num_aligned_scales = align(num_scales, sizeof(float) / sizeof(scale_t));
-    const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
+    const auto num_aligned_tokens = align(num_ranks * num_max_dispatch_tokens_per_rank, 128);
+    const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_aligned_tokens * num_aligned_scales;
 
     // Shared between sub-warps in warp groups
     __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
@@ -294,7 +481,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
 
         // Copy tokens
-        EP_DEVICE_ASSERT(num_scales <= 64);
         for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
             // Copy source info
             const auto src_src_idx = reinterpret_cast<const int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
@@ -310,6 +496,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
 
             // Copy scales
             if constexpr (kUseFP8) {
+                EP_DEVICE_ASSERT(num_scales <= 64);
                 // Equivalent CuTe layout:
                 //   (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
                 const auto src_scales = reinterpret_cast<const float*>(reinterpret_cast<const uint8_t*>(src_data) + hidden_bytes);
@@ -329,6 +516,30 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
                     auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
                     recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                 }
+            } else if constexpr (kUseNVFP4) {
+                // The physical layout is (l, rm, rk, 32, 4, 4)
+                const auto src_scales = reinterpret_cast<const uint8_t*>(reinterpret_cast<const uint8_t*>(src_data) + hidden_bytes);
+                const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
+                const auto token_idx = recv_token_begin_idx + i;
+
+                const auto rk = align(kHidden / kNumPerChannels, 4) / 4;
+                const auto dim0_stride = rk * 128 * num_elems_per_pack;
+                const auto dim1_stride = 128 * num_elems_per_pack;
+                const auto dim2_stride = 4 * num_elems_per_pack;
+                const auto dim3_stride = num_elems_per_pack;
+
+                const auto dim0_offset = token_idx / 128;
+                const auto dim2_offset = (token_idx % 128) % 32;
+                const auto dim3_offset = (token_idx % 128) / 32;
+
+                #pragma unroll
+                for (int j = lane_id; j < num_scales; j += 32) {
+                    const auto dim1_offset = j / num_elems_per_pack;
+                    const auto dim4_offset = j % num_elems_per_pack;
+                    auto scale = ld_nc_global(src_scales + j);
+                    const auto offset = dim0_offset * dim0_stride + dim1_offset * dim1_stride + dim2_offset * dim2_stride + dim3_offset * dim3_stride + dim4_offset;
+                    recv_x_scales[offset] = scale;
+                }
             }
         }
     }
@@ -339,12 +550,14 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
              int64_t* dispatch_wait_recv_cost_stats,
+              const float* x_global_scale,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
+              bool use_nvfp4, bool use_ue8m0_for_sf,
               void* workspace, int num_device_sms,
               cudaStream_t stream, int phases) {
     constexpr int kNumMaxTopK = 9;
@@ -367,17 +580,22 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
 
 #define DISPATCH_LAUNCH_CASE(hidden) { \
-auto dispatch_func = dispatch