
Commit 541a2ef

[Perf] Deepgemm fused layout kernel for activations, 4.3% throughput improvement, 10.7% TTFT improvement. (#29546)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent b0f4866 commit 541a2ef

5 files changed: +311 -12 lines


csrc/ops.h

Lines changed: 8 additions & 0 deletions
@@ -299,6 +299,14 @@ void per_token_group_quant_int8(const torch::Tensor& input,
                                 torch::Tensor& output_q,
                                 torch::Tensor& output_s, int64_t group_size,
                                 double eps, double int8_min, double int8_max);
+
+// Fused activation quantisation + DeepGEMM-compatible UE8M0-packed scales.
+void per_token_group_quant_8bit_packed(const torch::Tensor& input,
+                                       torch::Tensor& output_q,
+                                       torch::Tensor& output_s_packed,
+                                       int64_t group_size, double eps,
+                                       double min_8bit, double max_8bit);
+
 #endif

 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
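
For readers unfamiliar with UE8M0: each per-group scale produced by this kernel is rounded up to a power of two, so only the 8-bit biased exponent of its float32 representation needs to be stored. A minimal Python sketch of that encoding and its inverse (the helper names here are illustrative, not part of the vLLM API):

import math
import struct

def ue8m0_encode(scale: float) -> int:
    """Round a positive scale up to a power of two and return the
    8-bit biased exponent of its float32 representation."""
    pow2 = 2.0 ** math.ceil(math.log2(max(abs(scale), 1e-10)))
    bits = struct.unpack("<I", struct.pack("<f", pow2))[0]
    return (bits >> 23) & 0xFF  # same extraction as the CUDA kernel

def ue8m0_decode(exponent: int) -> float:
    """Rebuild the power-of-two scale from its stored exponent byte."""
    bits = (exponent & 0xFF) << 23
    return struct.unpack("<f", struct.pack("<I", bits))[0]

assert ue8m0_decode(ue8m0_encode(0.37)) == 0.5  # 0.37 rounds up to 2**-1
assert ue8m0_decode(127) == 1.0                 # biased exponent 127 == 2**0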

csrc/quantization/w8a8/fp8/per_token_group_quant.cu

Lines changed: 185 additions & 0 deletions
@@ -206,6 +206,191 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
 #undef LAUNCH_KERNEL
 }

+template <typename T, typename DST_DTYPE>
+__global__ void per_token_group_quant_8bit_packed_kernel(
+    const T* __restrict__ input, void* __restrict__ output_q,
+    unsigned int* __restrict__ output_s_packed, const int group_size,
+    const int num_groups, const int groups_per_block, const int groups_per_row,
+    const int mn, const int tma_aligned_mn, const float eps,
+    const float min_8bit, const float max_8bit) {
+  const int threads_per_group = 16;
+  const int64_t local_group_id = threadIdx.x / threads_per_group;
+  const int lane_id = threadIdx.x % threads_per_group;
+
+  const int64_t block_group_id = blockIdx.x * groups_per_block;
+  const int64_t global_group_id = block_group_id + local_group_id;
+  if (global_group_id >= num_groups) {
+    return;
+  }
+
+  const int64_t block_group_offset = global_group_id * group_size;
+
+  float local_absmax = eps;
+
+  const T* group_input = input + block_group_offset;
+  DST_DTYPE* group_output =
+      static_cast<DST_DTYPE*>(output_q) + block_group_offset;
+
+  // shared memory to cache each group's data to avoid double DRAM reads.
+  extern __shared__ __align__(16) char smem_raw[];
+  T* smem = reinterpret_cast<T*>(smem_raw);
+  T* smem_group = smem + local_group_id * group_size;
+
+  constexpr int vec_size = 16 / sizeof(T);
+  using vec_t = vllm::vec_n_t<T, vec_size>;
+
+  // copy global -> shared & compute absmax
+  auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
+    float abs_v = fabsf(static_cast<float>(src));
+    local_absmax = fmaxf(local_absmax, abs_v);
+    dst = src;
+  };
+
+  vllm::vectorize_with_alignment<vec_size>(
+      group_input,        // in
+      smem_group,         // out (shared)
+      group_size,         // elements per group
+      lane_id,            // thread id
+      threads_per_group,  // stride in group
+      scalar_op_cache);   // scalar handler
+
+  local_absmax = GroupReduceMax(local_absmax);
+
+  float y_s = local_absmax / max_8bit;
+  y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
+
+  // pack 4 scales into a uint32
+  if (lane_id == 0) {
+    // map flat group id to 2D indices (mn_idx, sf_k_idx)
+    const int sf_k_idx = static_cast<int>(global_group_id % groups_per_row);
+    const int mn_idx = static_cast<int>(global_group_id / groups_per_row);
+
+    if (mn_idx < mn) {
+      // each uint32 in output_s_packed stores 4 packed scales
+      const int sf_k_pack_idx = sf_k_idx / 4;
+      const int pos = sf_k_idx % 4;
+
+      // reinterpret the UE8M0 scale y_s as IEEE bits, extract the 8-bit
+      // exponent, and place it into the correct byte of the 32-bit word.
+      const unsigned int bits = __float_as_uint(y_s);
+      const unsigned int exponent = (bits >> 23u) & 0xffu;
+      const unsigned int contrib = exponent << (pos * 8u);
+
+      const int out_idx = sf_k_pack_idx * tma_aligned_mn + mn_idx;
+      // atomically OR 8-bit exponent into the packed scales buffer
+      atomicOr(output_s_packed + out_idx, contrib);
+    }
+  }
+
+  __syncthreads();
+
+  // quantize shared -> global 8-bit
+  auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
+    float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
+    dst = DST_DTYPE(q);
+  };
+
+  vllm::vectorize_with_alignment<vec_size>(
+      smem_group,         // in (shared)
+      group_output,       // out (global quant tensor)
+      group_size,         // elements
+      lane_id,            // tid
+      threads_per_group,  // stride
+      scalar_op_quant);   // scalar handler
+}
+
+void per_token_group_quant_8bit_packed(const torch::Tensor& input,
+                                       torch::Tensor& output_q,
+                                       torch::Tensor& output_s_packed,
+                                       int64_t group_size, double eps,
+                                       double min_8bit, double max_8bit) {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(output_q.is_contiguous());
+
+  const int64_t k = input.size(-1);
+  TORCH_CHECK(k % group_size == 0, "Last dimension (", k,
+              ") must be divisible by group_size (", group_size, ").");
+
+  const int64_t mn = input.numel() / k;
+  const int64_t groups_per_row = k / group_size;
+  const int64_t num_groups = mn * groups_per_row;
+
+  TORCH_CHECK(output_s_packed.dim() == 2,
+              "output_s_packed must be 2D, got dim=", output_s_packed.dim(),
+              ".");
+
+  const int64_t k_num_packed_sfk = (groups_per_row + 3) / 4;
+  const int64_t tma_aligned_mn = ((mn + 3) / 4) * 4;
+
+  TORCH_CHECK(output_s_packed.scalar_type() == at::ScalarType::Int,
+              "output_s_packed must have dtype int32 for UE8M0-packed scales.");
+  // DeepGEMM expects SFA scales in MN-major form with shape
+  // [mn, ceil_div(K, 128 * 4)] and TMA-aligned stride on the last
+  // dimension.
+  TORCH_CHECK(output_s_packed.size(0) == mn &&
+                  output_s_packed.size(1) == k_num_packed_sfk,
+              "output_s_packed shape must be [", mn, ", ", k_num_packed_sfk,
+              "], but got [", output_s_packed.size(0), ", ",
+              output_s_packed.size(1), "].");
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  constexpr int THREADS_PER_GROUP = 16;
+
+  int groups_per_block = 1;
+
+  if (num_groups % 16 == 0) {
+    groups_per_block = 16;
+  } else if (num_groups % 8 == 0) {
+    groups_per_block = 8;
+  } else if (num_groups % 4 == 0) {
+    groups_per_block = 4;
+  } else if (num_groups % 2 == 0) {
+    groups_per_block = 2;
+  }
+
+  auto dst_type = output_q.scalar_type();
+  const int num_blocks = num_groups / groups_per_block;
+  const int num_threads = groups_per_block * THREADS_PER_GROUP;
+
+  // zero-initialize packed scales, since we use atomicOr to accumulate
+  // exponents from different groups.
+  output_s_packed.zero_();
+
+#define LAUNCH_PACKED_KERNEL(T, DST_DTYPE)                                 \
+  do {                                                                     \
+    dim3 grid(num_blocks);                                                 \
+    dim3 block(num_threads);                                               \
+    size_t smem_bytes =                                                    \
+        static_cast<size_t>(groups_per_block) * group_size * sizeof(T);    \
+    per_token_group_quant_8bit_packed_kernel<T, DST_DTYPE>                 \
+        <<<grid, block, smem_bytes, stream>>>(                             \
+            static_cast<const T*>(input.data_ptr()), output_q.data_ptr(),  \
+            reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()),   \
+            static_cast<int>(group_size), static_cast<int>(num_groups),    \
+            groups_per_block, static_cast<int>(groups_per_row),            \
+            static_cast<int>(mn), static_cast<int>(tma_aligned_mn),        \
+            static_cast<float>(eps), static_cast<float>(min_8bit),         \
+            static_cast<float>(max_8bit));                                 \
+  } while (0)
+
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "per_token_group_quant_8bit_packed", ([&] {
+        if (dst_type == at::ScalarType::Float8_e4m3fn) {
+          LAUNCH_PACKED_KERNEL(scalar_t, __nv_fp8_e4m3);
+        } else if (dst_type == at::ScalarType::Char) {
+          LAUNCH_PACKED_KERNEL(scalar_t, int8_t);
+        } else {
+          TORCH_CHECK(
+              false,
+              "per_token_group_quant_8bit_packed only supports FP8/INT8 "
+              "outputs.");
+        }
+      }));
+
+#undef LAUNCH_PACKED_KERNEL
+}
+
 void per_token_group_quant_fp8(const torch::Tensor& input,
                                torch::Tensor& output_q, torch::Tensor& output_s,
                                int64_t group_size, double eps, double fp8_min,
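
To make the scale-packing logic above easier to follow, here is a hedged pure-Python reference of what each group does: compute the group absmax, round the scale up to a power of two, extract its float32 exponent byte, and OR it into the MN-major, TMA-aligned int32 buffer at out_idx = sf_k_pack_idx * tma_aligned_mn + mn_idx. This mirrors the kernel for illustration only; it is not the code path vLLM runs.

import math

def pack_group_scale(packed: list[int], group: list[float], group_id: int,
                     groups_per_row: int, tma_aligned_mn: int,
                     fp8_max: float = 448.0, eps: float = 1e-10) -> float:
    """Compute one group's UE8M0 scale and OR its exponent byte into
    `packed` (a flat list standing in for the int32 scale buffer)."""
    absmax = max(eps, max(abs(v) for v in group))
    y_s = 2.0 ** math.ceil(math.log2(max(abs(absmax / fp8_max), 1e-10)))
    exponent = (math.frexp(y_s)[1] - 1) + 127      # biased float32 exponent

    mn_idx, sf_k_idx = divmod(group_id, groups_per_row)
    sf_k_pack_idx, pos = divmod(sf_k_idx, 4)
    out_idx = sf_k_pack_idx * tma_aligned_mn + mn_idx
    packed[out_idx] |= exponent << (pos * 8)       # same effect as the atomicOr
    return y_s

# 2 rows, 8 groups per row -> 2 packed columns, tma_aligned_mn = ((2 + 3) // 4) * 4 = 4
packed = [0] * (2 * 4)
y_s = pack_group_scale(packed, [0.1, -3.0, 2.5], group_id=9,
                       groups_per_row=8, tma_aligned_mn=4)

On the launch side, the host wrapper assigns 16 threads per group, picks groups_per_block as the largest of {16, 8, 4, 2, 1} dividing num_groups, and sizes dynamic shared memory to groups_per_block * group_size input elements.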

csrc/torch_bindings.cpp

Lines changed: 9 additions & 0 deletions
@@ -617,6 +617,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("per_token_group_fp8_quant", torch::kCUDA,
            &per_token_group_quant_fp8);

+  // Compute per-token-group 8-bit quantized tensor and UE8M0-packed,
+  // TMA-aligned scales for DeepGEMM.
+  ops.def(
+      "per_token_group_fp8_quant_packed(Tensor input, Tensor! output_q, "
+      "Tensor! output_s_packed, int group_size, float eps, float fp8_min, "
+      "float fp8_max) -> ()");
+  ops.impl("per_token_group_fp8_quant_packed", torch::kCUDA,
+           &per_token_group_quant_8bit_packed);
+
   // Compute per-token-group INT8 quantized tensor and scaling factor.
   ops.def(
       "per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
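
A hedged usage sketch of the newly registered op, mirroring how the Python wrapper later in this commit allocates its outputs (assumes a CUDA build of vLLM containing this commit; tensor sizes are illustrative):

import torch

x = torch.randn(64, 1024, dtype=torch.bfloat16, device="cuda")
group_size, eps = 128, 1e-10
finfo = torch.finfo(torch.float8_e4m3fn)

mn, k = x.shape
groups_per_row = k // group_size
out_q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
# [mn, ceil(groups_per_row / 4)] int32, TMA-aligned stride on the packed-K dim
out_s_packed = torch.empty_strided(
    (mn, (groups_per_row + 3) // 4),
    (1, ((mn + 3) // 4) * 4),
    dtype=torch.int32,
    device=x.device,
)

torch.ops._C.per_token_group_fp8_quant_packed(
    x, out_q, out_s_packed, group_size, eps, finfo.min, finfo.max
)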

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 30 additions & 11 deletions
@@ -23,9 +23,11 @@
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
+    per_token_group_quant_fp8_packed_for_deepgemm,
     silu_mul_per_token_group_quant_fp8_colmajor,
 )
 from vllm.utils.deep_gemm import (
+    DeepGemmQuantScaleFMT,
     get_mk_alignment_for_contiguous_layout,
     m_grouped_fp8_gemm_nt_contiguous,
 )
@@ -157,23 +157,40 @@ def workspace_shapes(
     def _act_mul_quant(
         self, input: torch.Tensor, output: torch.Tensor, activation: str
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        if activation == "silu":
-            return silu_mul_per_token_group_quant_fp8_colmajor(
-                input=input, output=output
-            )
-        else:
-            # This is a fallback path. If we find ourselves using any activation other
-            # than silu, we should add that activation to
-            # silu_mul_per_token_group_quant_fp8_colmajor kernel as it is much faster.
+        assert self.block_shape is not None
+        block_k = self.block_shape[1]
+        scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
+
+        # 1. DeepGemm UE8M0: use packed per-token-group quant
+        if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
             M_sum, N = input.size()
             act_out = torch.empty(
                 (M_sum, N // 2), dtype=input.dtype, device=input.device
             )
             self.activation(activation, act_out, input)
-            assert self.block_shape is not None
-            return per_token_group_quant_fp8(
-                act_out, self.block_shape[1], column_major_scales=True, out_q=output
+            a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
+                act_out,
+                block_k,
+                out_q=output,
             )
+            return a2q, a2q_scale
+
+        # 2. Hopper / non‑E8M0: prefer the fused SiLU+mul+quant kernel
+        if activation == "silu":
+            use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
+            return silu_mul_per_token_group_quant_fp8_colmajor(
+                input=input,
+                output=output,
+                use_ue8m0=use_ue8m0,
+            )
+
+        # 3. fallback path for non-SiLU activations in non‑UE8M0 cases.
+        M_sum, N = input.size()
+        act_out = torch.empty((M_sum, N // 2), dtype=input.dtype, device=input.device)
+        self.activation(activation, act_out, input)
+        return per_token_group_quant_fp8(
+            act_out, block_k, column_major_scales=True, out_q=output
+        )

     def apply(
         self,
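
For a concrete sense of the shapes flowing through the UE8M0 branch of _act_mul_quant above (numbers are illustrative): the gate/up projection output `input` has width N, the activation output halves it, and the packed scales use one int32 column per 4 groups of `block_k` channels.

# Illustrative shape bookkeeping for _act_mul_quant's UE8M0 branch.
M_sum, N = 64, 4096          # tokens in this expert batch, gate+up width
block_k = 128                # self.block_shape[1]

act_out_shape = (M_sum, N // 2)                       # (64, 2048) after silu-and-mul
groups_per_row = (N // 2) // block_k                  # 16 quant groups per row
a2q_scale_shape = (M_sum, (groups_per_row + 3) // 4)  # (64, 4) packed int32 scales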

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 79 additions & 1 deletion
@@ -269,7 +269,11 @@ def _run_deepgemm(
         weight_scale: torch.Tensor,
     ) -> torch.Tensor:
         assert self.deepgemm_input_quant_op is not None
-        q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
+        q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
+            input_2d,
+            group_size=self.act_quant_group_shape.col,
+            use_ue8m0=True,
+        )
         output = torch.empty(
             (q_input.shape[0], weight.shape[0]),
             dtype=torch.bfloat16,
@@ -791,6 +795,80 @@ def per_token_group_quant_fp8(
     return x_q, x_s


+def per_token_group_quant_fp8_packed_for_deepgemm(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    use_ue8m0: bool | None = None,
+    out_q: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """FP8 per-token-group quantization for DeepGEMM.
+
+    Returns:
+        (x_q, x_s_packed)
+        x_q: FP8 activations, same shape as `x`.
+        x_s_packed: Int32 tensor with logical shape
+            [mn, ceil(num_groups_per_row / 4)], laid out with
+            TMA-aligned stride along the packed-K dimension
+    """
+    if use_ue8m0 is None:
+        use_ue8m0 = is_deep_gemm_e8m0_used()
+    # for DeepGEMM UE8M0-packed layout we *require* UE8M0 scales.
+    assert use_ue8m0, (
+        "per_token_group_quant_fp8_packed_for_deepgemm requires UE8M0 scales."
+    )
+
+    dtype = current_platform.fp8_dtype()
+    assert x.shape[-1] % group_size == 0, (
+        f"the last dimension of `x` {x.shape[-1]} must be divisible "
+        f"by `group_size` {group_size}"
+    )
+    assert x.stride(-1) == 1, "`x` groups must be contiguous"
+
+    finfo = torch.finfo(dtype)
+    fp8_min, fp8_max = finfo.min, finfo.max
+
+    # compute DeepGEMM-style packed scale tensor shape.
+    hidden_dim = x.shape[-1]
+    mn = x.numel() // hidden_dim
+    num_groups_per_row = hidden_dim // group_size
+    k_num_packed_sf_k = (num_groups_per_row + 3) // 4
+    tma_aligned_mn = ((mn + 3) // 4) * 4
+
+    x_s_packed = torch.empty_strided(
+        (mn, k_num_packed_sf_k),
+        (1, tma_aligned_mn),
+        device=x.device,
+        dtype=torch.int32,
+    )
+
+    # CUDA kernel path only (DeepGEMM + E8M0 is CUDA-specific).
+    assert current_platform.is_cuda(), (
+        "per_token_group_quant_fp8_packed_for_deepgemm is only valid on CUDA "
+        "platforms using DeepGEMM."
+    )
+
+    x_contiguous = x.contiguous()
+    if out_q is not None:
+        x_q_local = out_q
+    else:
+        x_q_local = torch.empty_like(x_contiguous, device=x.device, dtype=dtype)
+
+    torch.ops._C.per_token_group_fp8_quant_packed(
+        x_contiguous,
+        x_q_local,
+        x_s_packed,
+        group_size,
+        eps,
+        fp8_min,
+        fp8_max,
+    )
+
+    # return a tensor with the original logical shape.
+    x_q = x_q_local.view_as(x)
+    return x_q, x_s_packed
+
+
 @triton.jit
 def _w8a8_triton_block_scaled_mm(
     # Pointers to inputs and output
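
A hedged end-to-end sketch of the new helper (assumes a CUDA device and a vLLM build containing this commit; sizes are illustrative):

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8_packed_for_deepgemm,
)

x = torch.randn(6, 1024, dtype=torch.bfloat16, device="cuda")
x_q, x_s_packed = per_token_group_quant_fp8_packed_for_deepgemm(
    x, group_size=128, use_ue8m0=True
)

assert x_q.shape == x.shape           # FP8 tensor, same logical shape as the input
assert x_s_packed.shape == (6, 2)     # 1024 / 128 = 8 groups per row -> 2 int32 columns
assert x_s_packed.stride() == (1, 8)  # MN-major, TMA-aligned to ((6 + 3) // 4) * 4 = 8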
