From 45a3008ceb0b9f55b23fdc3dc8d4f4be480b86aa Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Mon, 10 Nov 2025 08:02:17 +0000
Subject: [PATCH 1/2] feat: Integrate AITER bpreshuffle and CK operators on top
 of fp8 refactor

Signed-off-by: vllmellm
---
 .../schemes/compressed_tensors_w8a8_fp8.py    |   2 +
 .../kernels/scaled_mm/__init__.py             |   4 +
 .../quantization/kernels/scaled_mm/aiter.py   | 217 +++++++++++++++++-
 3 files changed, 222 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 2cd29e0905d0..e25d2aaa439b 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -192,6 +192,8 @@ def process_weights_after_loading(self, layer) -> None:
         if self.strategy == QuantizationStrategy.BLOCK:
             maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
 
+        self.fp8_linear.process_weights_after_loading(layer)
+
     def apply_weights(
         self,
         layer: torch.nn.Module,
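The hunk above wires the FP8 scheme into the kernel-specific post-processing added in this patch: after the existing block-wise handling, the scheme now hands the layer to the selected ScaledMM kernel. A minimal sketch of that delegation, using a hypothetical stand-in class (the real scheme is CompressedTensorsW8A8Fp8):

```python
# Hypothetical stand-in for the scheme; only the hook order matters here.
class DemoFp8Scheme:
    def __init__(self, fp8_linear_kernel):
        # The ScaledMM kernel chosen for this platform/config.
        self.fp8_linear = fp8_linear_kernel

    def process_weights_after_loading(self, layer):
        # ... existing per-strategy post-processing runs first ...
        # New: give the kernel a chance to re-layout weights, e.g. the
        # AITER bpreshuffle kernel performs its one-time shuffle here.
        self.fp8_linear.process_weights_after_loading(layer)
```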
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
index b033cc7905e4..b8c7f78aac64 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
@@ -8,6 +8,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
+    AiterBpreshufflePerTokenFp8ScaledMMLinearKernel,
+    AiterCKPerTokenFp8ScaledMMLinearKernel,
     AiterScaledMMLinearKernel,
 )
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
@@ -64,6 +66,8 @@
         ChannelWiseTorchScaledMMLinearKernel,
     ],
     PlatformEnum.ROCM: [
+        AiterBpreshufflePerTokenFp8ScaledMMLinearKernel,
+        AiterCKPerTokenFp8ScaledMMLinearKernel,
         ROCmScaledMMLinearKernel,
         PerTensorTorchScaledMMLinearKernel,
         RowWiseTorchScaledMMLinearKernel,

diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
index 3ac90553bbc7..430e407156c5 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
@@ -2,15 +2,25 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from collections.abc import Callable
+
 import torch
+from aiter.ops.shuffle import shuffle_weight
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
 
 from .cutlass import CutlassScaledMMLinearKernel
-from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import (
+    FP8ScaledMMLinearKernel,
+    FP8ScaledMMLinearLayerConfig,
+    Int8ScaledMMLinearLayerConfig,
+)
+
+logger = init_logger(__name__)
 
 
 def rocm_aiter_gemm_w8a8_impl(
@@ -52,6 +62,54 @@ def rocm_aiter_gemm_w8a8_fake(
     )
 
 
+# bpreshuffle
+def rocm_aiter_gemm_a8w8_bpreshuffle_impl(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    out_dtype: torch.dtype | None = None,
+    scale_a: torch.Tensor | None = None,
+    scale_b: torch.Tensor | None = None,
+) -> torch.Tensor:
+    # This AITER function can be used for
+    # - per-token activations + per-channel weights
+    #   e.g. vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+    # The weight is kept in (N, K) layout.
+    # NOTE: The weight has to be shuffled in the
+    # process_weights_after_loading of the CompressedTensorsW8A8Fp8 class.
+
+    from aiter import gemm_a8w8_bpreshuffle_ck
+
+    m = input.shape[0]
+    n = weight.shape[0]
+    Y = torch.empty(m, n, dtype=out_dtype, device=input.device)
+    gemm_a8w8_bpreshuffle_ck(input, weight, scale_a, scale_b, Y)
+    return Y
+
+
+def rocm_aiter_gemm_a8w8_bpreshuffle_fake(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    out_dtype: torch.dtype | None = None,
+    scale_a: torch.Tensor | None = None,
+    scale_b: torch.Tensor | None = None,
+) -> torch.Tensor:
+    m = input.shape[0]
+    n = weight.shape[0]
+    if out_dtype is None:
+        out_dtype = input.dtype
+    return torch.empty((m, n), dtype=out_dtype, device=input.device)
+
+
+if current_platform.is_rocm():
+    direct_register_custom_op(
+        op_name="rocm_aiter_gemm_a8w8_bpreshuffle",
+        op_func=rocm_aiter_gemm_a8w8_bpreshuffle_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_gemm_a8w8_bpreshuffle_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
 class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
     @classmethod
     def get_min_capability(cls) -> int:
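The fake implementation registered above is what torch.compile traces instead of the ROCm kernel, so it must agree with the real op on output shape and dtype. A quick contract check, with illustrative shapes (assumes a torch build that exposes the float8_e4m3fnuz dtype):

```python
import torch

M, N, K = 32, 64, 128
x = torch.empty(M, K, dtype=torch.float8_e4m3fnuz)  # quantized activations
w = torch.empty(N, K, dtype=torch.float8_e4m3fnuz)  # pre-shuffled weight
out = rocm_aiter_gemm_a8w8_bpreshuffle_fake(x, w, out_dtype=torch.bfloat16)
assert out.shape == (M, N) and out.dtype == torch.bfloat16
```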
@@ -157,3 +215,160 @@ def apply_weights(
         return torch.ops.vllm.rocm_aiter_gemm_w8a8(
             x_q, w_q.t(), x_s, w_s, bias, out_dtype
         )
+
+
+# bpreshuffle
+class AiterBpreshufflePerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
+    def get_ouput_padding(self) -> int | None:
+        # PTPC kernels do not require padding.
+        return None
+
+    @classmethod
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+        if not current_platform.is_rocm():
+            return (False, "AITER bpreshuffle is ROCm-only")
+        if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER):
+            return (False, "AITER bpreshuffle is disabled by env var")
+        try:
+            import aiter  # noqa: F401
+        except Exception:
+            return (False, "AITER not installed")
+
+        # Check if the configuration is PTPC
+        is_per_channel_weight = c.weight_quant_key.scale.group_shape.is_per_token()
+        is_per_token_activation = (
+            c.activation_quant_key.scale.group_shape.is_per_token()
+        )
+        is_ptpc = is_per_channel_weight and is_per_token_activation
+
+        logger.info_once(f"AiterBpreshuffle: can_implement called. is_ptpc={is_ptpc}")
+
+        if not is_ptpc:
+            return (False, "This kernel only handles Per-Token/Per-Channel (PTPC)")
+
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        logger.info_once("AiterBpreshuffle: shuffling weights.")
+
+        w_q, _, _, _ = self._get_layer_params(layer)
+
+        N = w_q.shape[1]
+        K = w_q.shape[0]
+
+        if N % 16 == 0 and K % 16 == 0:
+            # AITER's shuffle_weight expects an [N, K] tensor
+            w_q_nk = w_q.t().contiguous()
+
+            # Shuffle into the tiled layout
+            shuffled_w_nk = shuffle_weight(w_q_nk, layout=(16, 16))
+
+            del layer.weight
+            layer.register_buffer("weight", shuffled_w_nk)
+
+            logger.info_once("AiterBpreshuffle: weight shuffle complete.")
+
+        else:
+            raise ValueError(
+                f"Weight shape (N={N}, K={K}) not divisible by 16 "
+                "for AITER bpreshuffle."
+            )
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        # 1. Obtain parameters
+        w_q, w_s, x_s, x_s_ub = self._get_layer_params(layer)
+        # 2. Dynamically quantize the input (per-token scales)
+        qinput, qinput_scale = self.quant_fp8(x, x_s, x_s_ub)
+
+        logger.info_once(
+            "AiterBpreshuffle: apply_weights calling the AITER kernel."
+        )
+
+        # 3. Call the AITER bpreshuffle CK operator.
+        output = torch.ops.vllm.rocm_aiter_gemm_a8w8_bpreshuffle(
+            qinput,
+            w_q,  # pre-shuffled [N, K] weight
+            out_dtype=self.config.out_dtype,
+            scale_a=qinput_scale,
+            scale_b=w_s,
+        )
+
+        logger.info_once("AiterBpreshuffle: AITER kernel call succeeded.")
+
+        if bias is not None:
+            output.add_(bias)
+        return output
+
+    def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
+        return rocm_aiter_gemm_a8w8_bpreshuffle_impl
+
+
+# AITER FP8 CK
+class AiterCKPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
+    """
+    AITER PTPC kernel (gemm_a8w8_CK) without pre-shuffling.
+    """
+
+    def get_ouput_padding(self) -> int | None:
+        return None
+
+    @classmethod
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+        if not current_platform.is_rocm():
+            return (False, "AITER CK is ROCm-only")
+        if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER):
+            return (False, "AITER CK is disabled by env var")
+        try:
+            import aiter  # noqa: F401
+        except Exception:
+            return (False, "AITER not installed")
+
+        is_per_channel_weight = c.weight_quant_key.scale.group_shape.is_per_token()
+        is_per_token_activation = (
+            c.activation_quant_key.scale.group_shape.is_per_token()
+        )
+        is_ptpc = is_per_channel_weight and is_per_token_activation
+
+        logger.info_once(f"AiterCK: can_implement called. is_ptpc={is_ptpc}")
+
+        if not is_ptpc:
+            return (False, "This kernel only handles Per-Token/Per-Channel (PTPC)")
+
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # The CK kernel consumes the weight as loaded; no re-layout is needed.
+        logger.info_once("AiterCK: no weight post-processing required.")
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        w_q, w_s, x_s, x_s_ub = self._get_layer_params(layer)
+
+        qinput, qinput_scale = self.quant_fp8(x, x_s, x_s_ub)
+
+        logger.info_once(
+            "AiterCK: apply_weights calling the AITER CK GEMM kernel."
+        )
+
+        output = torch.ops.vllm.rocm_aiter_gemm_w8a8(
+            qinput, w_q.t(), qinput_scale, w_s, bias, self.config.out_dtype
+        )
+
+        logger.info_once("AiterCK: AITER kernel call succeeded.")
+
+        return output
+
+    def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
+        return rocm_aiter_gemm_w8a8_impl
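Ordering in the PlatformEnum.ROCM list is significant: vLLM walks the candidates in order and instantiates the first kernel whose can_implement accepts the config, so both AITER PTPC kernels are tried before the generic ROCm and torch fallbacks. A simplified sketch of that selection loop (the real logic lives in kernels/scaled_mm/__init__.py):

```python
def choose_kernel(candidates, config):
    # Candidates are tried in list order; the first match wins.
    reasons = []
    for kernel_cls in candidates:
        ok, reason = kernel_cls.can_implement(config)
        if ok:
            return kernel_cls
        reasons.append(f"{kernel_cls.__name__}: {reason}")
    raise ValueError("No compatible kernel found: " + "; ".join(reasons))
```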
+ ) + + output = torch.ops.vllm.rocm_aiter_gemm_w8a8( + qinput, w_q.t(), qinput_scale, w_s, bias, self.config.out_dtype + ) + + logger.info_once("AiterCK: C++ KERNEL CALL SUCCEEDED.") + + return output + + def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: + return rocm_aiter_gemm_w8a8_impl From 679a7cffdc3baf2a2f205d993a60a8925ebfd358 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 17 Nov 2025 06:29:29 +0000 Subject: [PATCH 2/2] WIP: Integrate Aiter bpreshuffle and ck kernels Signed-off-by: vllmellm --- vllm/_aiter_ops.py | 56 ++++++ .../kernels/scaled_mm/__init__.py | 4 + .../quantization/kernels/scaled_mm/aiter.py | 165 ++++++++++++++++++ 3 files changed, 225 insertions(+) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 5508e59bcd2f..6de21176e948 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -402,6 +402,42 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( return torch.empty_like(x), torch.empty_like(residual) +def _rocm_aiter_gemm_a8w8_bpreshuffle_impl( + input: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype | None = None, + scale_a: torch.Tensor | None = None, + scale_b: torch.Tensor | None = None, +) -> torch.Tensor: + # This AITER function can be used for + # - per-token activations + per-channel weights + # accept the weight as # keep the weight as (N, K) + # NOTE: The weight has to be shuffled in the + # process_weights_after_loading of the CompressedTensorsW8A8Fp8 class + + from aiter import gemm_a8w8_bpreshuffle_ck + + m = input.shape[0] + n = weight.shape[0] + Y = torch.empty(m, n, dtype=out_dtype, device=input.device) + gemm_a8w8_bpreshuffle_ck(input, weight, scale_a, scale_b, Y) + return Y + + +def _rocm_aiter_gemm_a8w8_bpreshuffle_fake( + input: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype | None = None, + scale_a: torch.Tensor | None = None, + scale_b: torch.Tensor | None = None, +) -> torch.Tensor: + m = input.shape[0] + n = weight.shape[0] + if out_dtype is None: + out_dtype = input.dtype + return torch.empty((m, n), dtype=out_dtype, device=input.device) + + # Global flag to ensure ops are registered only once _OPS_REGISTERED = False @@ -592,6 +628,14 @@ def register_ops_once() -> None: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_gemm_a8w8_bpreshuffle", + op_func=_rocm_aiter_gemm_a8w8_bpreshuffle_impl, + mutates_args=[], + fake_impl=_rocm_aiter_gemm_a8w8_bpreshuffle_fake, + dispatch_key=current_platform.dispatch_key, + ) + _OPS_REGISTERED = True @staticmethod @@ -635,6 +679,18 @@ def gemm_a8w8_blockscale( A, B, As, Bs, output_dtype ) + @staticmethod + def gemm_a8w8_bpreshuffle( + input: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype | None = None, + scale_a: torch.Tensor | None = None, + scale_b: torch.Tensor | None = None, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_gemm_a8w8_bpreshuffle( + input, weight, out_dtype, scale_a, scale_b + ) + @staticmethod def fused_moe( hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 36e4a16c0168..90cbda90adf9 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -8,6 +8,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( + AiterBpreshufflePerTokenFp8ScaledMMLinearKernel, + 
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
index 36e4a16c0168..90cbda90adf9 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
@@ -8,6 +8,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
+    AiterBpreshufflePerTokenFp8ScaledMMLinearKernel,
+    AiterCKPerTokenFp8ScaledMMLinearKernel,
     AiterScaledMMLinearKernel,
 )
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
@@ -64,6 +66,8 @@
         ChannelWiseTorchScaledMMLinearKernel,
     ],
     PlatformEnum.ROCM: [
+        AiterBpreshufflePerTokenFp8ScaledMMLinearKernel,
+        AiterCKPerTokenFp8ScaledMMLinearKernel,
         ROCmScaledMMLinearKernel,
         PerTensorTorchScaledMMLinearKernel,
         RowWiseTorchScaledMMLinearKernel,

diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
index 4a1c76ffd9b1..28c5640d319a 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
@@ -2,17 +2,25 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from collections.abc import Callable
+
 import torch
+from aiter.ops.shuffle import shuffle_weight
 
 from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
 from .cutlass import CutlassScaledMMLinearKernel
 from .ScaledMMLinearKernel import (
+    FP8ScaledMMLinearKernel,
+    FP8ScaledMMLinearLayerConfig,
     Int8ScaledMMLinearLayerConfig,
 )
 
+logger = init_logger(__name__)
+
 
 class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
     @classmethod
@@ -117,3 +125,160 @@ def apply_weights(
         # b to be [N, K]
         # CutlassScaledMMLinearKernel prepares weight `w_q` in [K, N] format
         return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
+
+
+class AiterBpreshufflePerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
+    def get_ouput_padding(self) -> int | None:
+        # PTPC kernels do not require padding.
+        return None
+
+    @classmethod
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+        if not current_platform.is_rocm():
+            return (False, "AITER bpreshuffle is ROCm-only")
+
+        if not rocm_aiter_ops.is_linear_enabled():
+            return (False, "AITER bpreshuffle is disabled by env var")
+
+        try:
+            import aiter  # noqa: F401
+        except Exception:
+            return (False, "AITER not installed")
+
+        # Check if the configuration is PTPC
+        is_per_channel_weight = c.weight_quant_key.scale.group_shape.is_per_token()
+        is_per_token_activation = (
+            c.activation_quant_key.scale.group_shape.is_per_token()
+        )
+        is_ptpc = is_per_channel_weight and is_per_token_activation
+
+        logger.info_once(f"AiterBpreshuffle: can_implement called. is_ptpc={is_ptpc}")
+
+        if not is_ptpc:
+            return (False, "This kernel only handles Per-Token/Per-Channel (PTPC)")
+
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        logger.info_once("AiterBpreshuffle: shuffling weights.")
+
+        w_q, _, _, _ = self._get_layer_params(layer)
+
+        N = w_q.shape[1]
+        K = w_q.shape[0]
+
+        if N % 16 == 0 and K % 16 == 0:
+            # AITER's shuffle_weight expects an [N, K] tensor
+            w_q_nk = w_q.t().contiguous()
+
+            # Shuffle into the tiled layout
+            shuffled_w_nk = shuffle_weight(w_q_nk, layout=(16, 16))
+
+            del layer.weight
+            layer.register_buffer("weight", shuffled_w_nk)
+
+            logger.info_once("AiterBpreshuffle: weight shuffle complete.")
+
+        else:
+            raise ValueError(
+                f"Weight shape (N={N}, K={K}) not divisible by 16 "
+                "for AITER bpreshuffle."
+            )
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        # 1. Obtain parameters
+        w_q, w_s, x_s, x_s_ub = self._get_layer_params(layer)
+        # 2. Dynamically quantize the input (per-token scales)
+        qinput, qinput_scale = self.quant_fp8(x, x_s, x_s_ub)
+
+        logger.info_once(
+            "AiterBpreshuffle: apply_weights calling the AITER kernel."
+        )
+
+        output = rocm_aiter_ops.gemm_a8w8_bpreshuffle(
+            qinput,
+            w_q,  # pre-shuffled [N, K] weight
+            out_dtype=self.config.out_dtype,
+            scale_a=qinput_scale,
+            scale_b=w_s,
+        )
+
+        logger.info_once("AiterBpreshuffle: AITER kernel call succeeded.")
+
+        if bias is not None:
+            output.add_(bias)
+        return output
+
+    def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
+        return rocm_aiter_ops.gemm_a8w8_bpreshuffle
+
+
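Both PTPC kernels lean on self.quant_fp8 for dynamic per-token quantization: every activation row gets its own scale, which pairs with the per-channel weight scales in the GEMM. A standalone sketch of the scheme (illustrative only, not vLLM's QuantFP8 implementation):

```python
import torch

def per_token_quant_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    fp8 = torch.float8_e4m3fnuz  # ROCm fp8 flavor assumed
    finfo = torch.finfo(fp8)
    # One scale per token (row), shape [M, 1].
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / finfo.max
    q = (x / scale).clamp(finfo.min, finfo.max).to(fp8)
    return q, scale.to(torch.float32)
```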
+class AiterCKPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
+    """
+    AITER PTPC kernel (gemm_a8w8_CK) without pre-shuffling.
+    """
+
+    def get_ouput_padding(self) -> int | None:
+        return None
+
+    @classmethod
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+        if not current_platform.is_rocm():
+            return (False, "AITER CK is ROCm-only")
+
+        if not rocm_aiter_ops.is_linear_enabled():
+            return (False, "AITER CK is disabled by env var")
+
+        try:
+            import aiter  # noqa: F401
+        except Exception:
+            return (False, "AITER not installed")
+
+        is_per_channel_weight = c.weight_quant_key.scale.group_shape.is_per_token()
+        is_per_token_activation = (
+            c.activation_quant_key.scale.group_shape.is_per_token()
+        )
+        is_ptpc = is_per_channel_weight and is_per_token_activation
+
+        logger.info_once(f"AiterCK: can_implement called. is_ptpc={is_ptpc}")
+
+        if not is_ptpc:
+            return (False, "This kernel only handles Per-Token/Per-Channel (PTPC)")
+
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # The CK kernel consumes the weight as loaded; no re-layout is needed.
+        logger.info_once("AiterCK: no weight post-processing required.")
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        w_q, w_s, x_s, x_s_ub = self._get_layer_params(layer)
+
+        qinput, qinput_scale = self.quant_fp8(x, x_s, x_s_ub)
+
+        logger.info_once(
+            "AiterCK: apply_weights calling the AITER CK GEMM kernel."
+        )
+
+        output = rocm_aiter_ops.gemm_a8w8(
+            qinput, w_q.t(), qinput_scale, w_s, bias, self.config.out_dtype
+        )
+
+        logger.info_once("AiterCK: AITER kernel call succeeded.")
+        return output
+
+    def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
+        return rocm_aiter_ops.gemm_a8w8
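Summing up the trade-off between the two new kernels: AiterCK consumes the weight exactly as the FP8 path prepares it ([K, N], so w_q.shape[0] is K and w_q.shape[1] is N) and transposes on every call, while AiterBpreshuffle pays a one-time relayout at load so each forward pass can hand the kernel an already-tiled [N, K] tensor. Side by side (variable names as in the diff; shapes illustrative):

```python
# CK path: no load-time work, transpose per call.
y_ck = rocm_aiter_ops.gemm_a8w8(
    qinput, w_q.t(), qinput_scale, w_s, bias, torch.bfloat16
)

# bpreshuffle path: one-time relayout in process_weights_after_loading...
shuffled_w = shuffle_weight(w_q.t().contiguous(), layout=(16, 16))
# ...then the stored [N, K] tensor is passed directly at each call.
y_bp = rocm_aiter_ops.gemm_a8w8_bpreshuffle(
    qinput, shuffled_w, out_dtype=torch.bfloat16,
    scale_a=qinput_scale, scale_b=w_s,
)
```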