@@ -206,6 +206,191 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
 #undef LAUNCH_KERNEL
 }
 
+template <typename T, typename DST_DTYPE>
+__global__ void per_token_group_quant_8bit_packed_kernel(
+    const T* __restrict__ input, void* __restrict__ output_q,
+    unsigned int* __restrict__ output_s_packed, const int group_size,
+    const int num_groups, const int groups_per_block, const int groups_per_row,
+    const int mn, const int tma_aligned_mn, const float eps,
+    const float min_8bit, const float max_8bit) {
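+  // each group of `group_size` contiguous elements is quantized cooperatively
+  // by 16 threads; one block handles `groups_per_block` such groups.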
+  const int threads_per_group = 16;
+  const int64_t local_group_id = threadIdx.x / threads_per_group;
+  const int lane_id = threadIdx.x % threads_per_group;
+
+  const int64_t block_group_id = blockIdx.x * groups_per_block;
+  const int64_t global_group_id = block_group_id + local_group_id;
+  if (global_group_id >= num_groups) {
+    return;
+  }
+
+  const int64_t block_group_offset = global_group_id * group_size;
+
+  float local_absmax = eps;
+
+  const T* group_input = input + block_group_offset;
+  DST_DTYPE* group_output =
+      static_cast<DST_DTYPE*>(output_q) + block_group_offset;
+
+  // shared memory to cache each group's data to avoid double DRAM reads.
+  extern __shared__ __align__(16) char smem_raw[];
+  T* smem = reinterpret_cast<T*>(smem_raw);
+  T* smem_group = smem + local_group_id * group_size;
+
+  constexpr int vec_size = 16 / sizeof(T);
+  using vec_t = vllm::vec_n_t<T, vec_size>;
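+  // vec_size elements of T span 16 bytes, i.e. one 128-bit vector.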
+
+  // copy global -> shared & compute absmax
+  auto scalar_op_cache = [&] __device__(T& dst, const T& src) {
+    float abs_v = fabsf(static_cast<float>(src));
+    local_absmax = fmaxf(local_absmax, abs_v);
+    dst = src;
+  };
+
+  vllm::vectorize_with_alignment<vec_size>(
+      group_input,        // in
+      smem_group,         // out (shared)
+      group_size,         // elements per group
+      lane_id,            // thread id
+      threads_per_group,  // stride in group
+      scalar_op_cache);   // scalar handler
+
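+  // combine the per-lane partial maxima into the group-wide absmax.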
+  local_absmax = GroupReduceMax(local_absmax);
+
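+  // round the scale up to the next power of two so that it is exactly
+  // representable as a UE8M0 exponent (8-bit exponent, no sign or mantissa).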
+  float y_s = local_absmax / max_8bit;
+  y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
+
+  // pack 4 scales into a uint32
+  if (lane_id == 0) {
+    // map flat group id to 2D indices (mn_idx, sf_k_idx)
+    const int sf_k_idx = static_cast<int>(global_group_id % groups_per_row);
+    const int mn_idx = static_cast<int>(global_group_id / groups_per_row);
+
+    if (mn_idx < mn) {
+      // each uint32 in output_s_packed stores 4 packed scales
+      const int sf_k_pack_idx = sf_k_idx / 4;
+      const int pos = sf_k_idx % 4;
+
+      // reinterpret the UE8M0 scale y_s as IEEE bits, extract the 8-bit
+      // exponent, and place it into the correct byte of the 32-bit word.
+      const unsigned int bits = __float_as_uint(y_s);
+      const unsigned int exponent = (bits >> 23u) & 0xffu;
+      const unsigned int contrib = exponent << (pos * 8u);
+
+      const int out_idx = sf_k_pack_idx * tma_aligned_mn + mn_idx;
+      // atomically OR 8-bit exponent into the packed scales buffer
+      atomicOr(output_s_packed + out_idx, contrib);
+    }
+  }
+
+  __syncthreads();
+
+  // quantize shared -> global 8-bit
+  auto scalar_op_quant = [&] __device__(DST_DTYPE& dst, const T& src) {
+    float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
+    dst = DST_DTYPE(q);
+  };
+
+  vllm::vectorize_with_alignment<vec_size>(
+      smem_group,         // in (shared)
+      group_output,       // out (global quant tensor)
+      group_size,         // elements
+      lane_id,            // tid
+      threads_per_group,  // stride
+      scalar_op_quant);   // scalar handler
+}
+
+void per_token_group_quant_8bit_packed(const torch::Tensor& input,
+                                       torch::Tensor& output_q,
+                                       torch::Tensor& output_s_packed,
+                                       int64_t group_size, double eps,
+                                       double min_8bit, double max_8bit) {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(output_q.is_contiguous());
+
+  const int64_t k = input.size(-1);
+  TORCH_CHECK(k % group_size == 0, "Last dimension (", k,
+              ") must be divisible by group_size (", group_size, ").");
+
+  const int64_t mn = input.numel() / k;
+  const int64_t groups_per_row = k / group_size;
+  const int64_t num_groups = mn * groups_per_row;
+
+  TORCH_CHECK(output_s_packed.dim() == 2,
+              "output_s_packed must be 2D, got dim=", output_s_packed.dim(),
+              ".");
+
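+  // four UE8M0 scales are packed into each int32 along K; mn is rounded up to
+  // a multiple of 4 to match the TMA-aligned stride of the scale buffer.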
+  const int64_t k_num_packed_sfk = (groups_per_row + 3) / 4;
+  const int64_t tma_aligned_mn = ((mn + 3) / 4) * 4;
+
+  TORCH_CHECK(output_s_packed.scalar_type() == at::ScalarType::Int,
+              "output_s_packed must have dtype int32 for UE8M0-packed scales.");
+  // DeepGEMM expects SFA scales in MN-major form with shape
+  // [mn, ceil_div(K, 128 * 4)] and TMA-aligned stride on the last
+  // dimension.
+  TORCH_CHECK(output_s_packed.size(0) == mn &&
+                  output_s_packed.size(1) == k_num_packed_sfk,
+              "output_s_packed shape must be [", mn, ", ", k_num_packed_sfk,
+              "], but got [", output_s_packed.size(0), ", ",
+              output_s_packed.size(1), "].");
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  constexpr int THREADS_PER_GROUP = 16;
+
+  int groups_per_block = 1;
+
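+  // pick the largest divisor of num_groups up to 16 so every block processes
+  // a whole number of groups and the grid size below divides evenly.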
+  if (num_groups % 16 == 0) {
+    groups_per_block = 16;
+  } else if (num_groups % 8 == 0) {
+    groups_per_block = 8;
+  } else if (num_groups % 4 == 0) {
+    groups_per_block = 4;
+  } else if (num_groups % 2 == 0) {
+    groups_per_block = 2;
+  }
+
+  auto dst_type = output_q.scalar_type();
+  const int num_blocks = num_groups / groups_per_block;
+  const int num_threads = groups_per_block * THREADS_PER_GROUP;
+
+  // zero-initialize packed scales, since we use atomicOr to accumulate
+  // exponents from different groups.
+  output_s_packed.zero_();
+
+#define LAUNCH_PACKED_KERNEL(T, DST_DTYPE)                                    \
+  do {                                                                        \
+    dim3 grid(num_blocks);                                                    \
+    dim3 block(num_threads);                                                  \
+    size_t smem_bytes =                                                       \
+        static_cast<size_t>(groups_per_block) * group_size * sizeof(T);       \
+    per_token_group_quant_8bit_packed_kernel<T, DST_DTYPE>                    \
+        <<<grid, block, smem_bytes, stream>>>(                                \
+            static_cast<const T*>(input.data_ptr()), output_q.data_ptr(),     \
+            reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()),      \
+            static_cast<int>(group_size), static_cast<int>(num_groups),       \
+            groups_per_block, static_cast<int>(groups_per_row),               \
+            static_cast<int>(mn), static_cast<int>(tma_aligned_mn),           \
+            static_cast<float>(eps), static_cast<float>(min_8bit),            \
+            static_cast<float>(max_8bit));                                    \
+  } while (0)
+
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "per_token_group_quant_8bit_packed", ([&] {
+        if (dst_type == at::ScalarType::Float8_e4m3fn) {
+          LAUNCH_PACKED_KERNEL(scalar_t, __nv_fp8_e4m3);
+        } else if (dst_type == at::ScalarType::Char) {
+          LAUNCH_PACKED_KERNEL(scalar_t, int8_t);
+        } else {
+          TORCH_CHECK(
+              false,
+              "per_token_group_quant_8bit_packed only supports FP8/INT8 "
+              "outputs.");
+        }
+      }));
+
+#undef LAUNCH_PACKED_KERNEL
+}
+
 void per_token_group_quant_fp8(const torch::Tensor& input,
                                torch::Tensor& output_q, torch::Tensor& output_s,
                                int64_t group_size, double eps, double fp8_min,