Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,42 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
return torch.empty_like(x), torch.empty_like(residual)


def _rocm_aiter_gemm_a8w8_bpreshuffle_impl(
input: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype | None = None,
scale_a: torch.Tensor | None = None,
scale_b: torch.Tensor | None = None,
) -> torch.Tensor:
# This AITER function can be used for
# - per-token activations + per-channel weights
# accept the weight as # keep the weight as (N, K)
# NOTE: The weight has to be shuffled in the
# process_weights_after_loading of the CompressedTensorsW8A8Fp8 class

from aiter import gemm_a8w8_bpreshuffle_ck

m = input.shape[0]
n = weight.shape[0]
Y = torch.empty(m, n, dtype=out_dtype, device=input.device)
gemm_a8w8_bpreshuffle_ck(input, weight, scale_a, scale_b, Y)
return Y


def _rocm_aiter_gemm_a8w8_bpreshuffle_fake(
input: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype | None = None,
scale_a: torch.Tensor | None = None,
scale_b: torch.Tensor | None = None,
) -> torch.Tensor:
m = input.shape[0]
n = weight.shape[0]
if out_dtype is None:
out_dtype = input.dtype
return torch.empty((m, n), dtype=out_dtype, device=input.device)


# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False

Expand Down Expand Up @@ -592,6 +628,14 @@ def register_ops_once() -> None:
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="rocm_aiter_gemm_a8w8_bpreshuffle",
op_func=_rocm_aiter_gemm_a8w8_bpreshuffle_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_a8w8_bpreshuffle_fake,
dispatch_key=current_platform.dispatch_key,
)

_OPS_REGISTERED = True

@staticmethod
Expand Down Expand Up @@ -635,6 +679,18 @@ def gemm_a8w8_blockscale(
A, B, As, Bs, output_dtype
)

@staticmethod
def gemm_a8w8_bpreshuffle(
input: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype | None = None,
scale_a: torch.Tensor | None = None,
scale_b: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_gemm_a8w8_bpreshuffle(
input, weight, out_dtype, scale_a, scale_b
)

@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ def process_weights_after_loading(self, layer) -> None:
if self.strategy == QuantizationStrategy.BLOCK:
maybe_post_process_fp8_weight_block(layer)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check if fp8_linear is initialised.

self.fp8_linear.process_weights_after_loading(layer)

def apply_weights(
self,
layer: torch.nn.Module,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterBpreshufflePerTokenFp8ScaledMMLinearKernel,
AiterCKPerTokenFp8ScaledMMLinearKernel,
AiterScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
Expand Down Expand Up @@ -64,6 +66,8 @@
ChannelWiseTorchScaledMMLinearKernel,
],
PlatformEnum.ROCM: [
AiterBpreshufflePerTokenFp8ScaledMMLinearKernel,
AiterCKPerTokenFp8ScaledMMLinearKernel,
ROCmScaledMMLinearKernel,
PerTensorTorchScaledMMLinearKernel,
RowWiseTorchScaledMMLinearKernel,
Expand Down
169 changes: 168 additions & 1 deletion vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,24 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from collections.abc import Callable

import torch
from aiter.ops.shuffle import shuffle_weight

from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.platforms import current_platform

from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearLayerConfig,
)

logger = init_logger(__name__)


class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
Expand Down Expand Up @@ -115,3 +125,160 @@ def apply_weights(
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)


class AiterBpreshufflePerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
def get_ouput_padding(self) -> int | None:
# PTPC kernels do not require padding.
return None

@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_rocm():
return (False, "AITER bpreshuffle is ROCm-only")

if not rocm_aiter_ops.is_linear_enabled():
return (False, "AITER bpreshuffle is disabled by env var")

try:
import aiter # noqa: F401
except Exception:
return (False, "AITER not installed")

# Check if the configuration is PTPC
is_per_channel_weight = c.weight_quant_key.scale.group_shape.is_per_token()
is_per_token_activation = (
c.activation_quant_key.scale.group_shape.is_per_token()
)
is_ptpc = is_per_channel_weight and is_per_token_activation

logger.info_once(f"AiterBpreshuffle: can_implement called. is_ptpc={is_ptpc}")

if not is_ptpc:
return (False, "This kernel only handles Per-Token/Per-Channel (PTPC)")

return True, None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
logger.info_once("AiterBpreshuffle: SHUFFLING WEIGHTS NOW.")

w_q, _, _, _ = self._get_layer_params(layer)

N = w_q.shape[1]
K = w_q.shape[0]

if N % 16 == 0 and K % 16 == 0:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# AITER shuffle_weight expectation [N, K]
w_q_nk = w_q.t().contiguous()

# Execute shuffle
shuffled_w_nk = shuffle_weight(w_q_nk, layout=(16, 16))

del layer.weight
layer.register_buffer("weight", shuffled_w_nk)

logger.info_once("[AiterBpreshuffle: Weight shuffle COMPLETE.")

else:
raise ValueError(
f"Weight shape (N={N}, K={K}) not divisible by 16 "
"for AITER bpreshuffle."
)

def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# 1. Obtain parameters
w_q, w_s, x_s, x_s_ub = self._get_layer_params(layer)
# 2. Dynamic quantization input
qinput, qinput_scale = self.quant_fp8(x, x_s, x_s_ub)

logger.info_once(
"AiterBpreshuffle: apply_weights... ABOUT TO CALL C++ KERNEL..."
)

output = rocm_aiter_ops.gemm_a8w8_bpreshuffle(
qinput,
w_q, # Input [N, K] shuffle weights
out_dtype=self.config.out_dtype,
scale_a=qinput_scale,
scale_b=w_s,
)

logger.info_once("AiterBpreshuffle: C++ KERNEL CALL SUCCEEDED.")

if bias is not None:
output.add_(bias)
return output

def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
return rocm_aiter_ops.gemm_a8w8_bpreshuffle


class AiterCKPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
"""
AITER PTPC kernel (gemm_a8w8_CK) without pre-shuffling.
"""

def get_ouput_padding(self) -> int | None:
return None

@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_rocm():
return (False, "AITER CK is ROCm-only")

if not rocm_aiter_ops.is_linear_enabled():
return (False, "AITER CK is disabled by env var")

try:
import aiter # noqa: F401
except Exception:
return (False, "AITER not installed")

is_per_channel_weight = c.weight_quant_key.scale.group_shape.is_per_token()
is_per_token_activation = (
c.activation_quant_key.scale.group_shape.is_per_token()
)
is_ptpc = is_per_channel_weight and is_per_token_activation

logger.info_once(f"AiterCK: can_implement called. is_ptpc={is_ptpc}")

if not is_ptpc:
return (False, "This kernel only handles Per-Token/Per-Channel (PTPC)")

return True, None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
logger.info_once(
"AITER CK: process_weights_after_loading... DOING NOTHING (pass)."
)
pass

def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, x_s, x_s_ub = self._get_layer_params(layer)

qinput, qinput_scale = self.quant_fp8(x, x_s, x_s_ub)

logger.info_once(
"AiterCK: apply_weights... "
"ABOUT TO CALL C++ KERNEL (this is where it hangs)..."
)

output = rocm_aiter_ops.gemm_a8w8(
qinput, w_q.t(), qinput_scale, w_s, bias, self.config.out_dtype
)

logger.info_once("AiterCK: C++ KERNEL CALL SUCCEEDED.")
return output

def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
return rocm_aiter_ops.gemm_a8w8