forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 5
[DO NOT MERGE] Refactor/aiter integration #76
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
vllmellm
wants to merge
4
commits into
refactor-fp8-linear
Choose a base branch
from
refactor/aiter_integration
base: refactor-fp8-linear
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
45a3008
feat: Integrate AITER bpreshuffle and ck operators on top of fp8 refa…
vllmellm 45803e1
Merge remote-tracking branch 'origin/refactor-fp8-linear' into refact…
vllmellm 686b3ec
Merge remote-tracking branch 'origin/refactor-fp8-linear' into refact…
vllmellm 679a7cf
WIP: Integrate Aiter bpreshuffle and ck kernels
vllmellm File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,14 +2,24 @@ | |
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
|
|
||
| from collections.abc import Callable | ||
|
|
||
| import torch | ||
| from aiter.ops.shuffle import shuffle_weight | ||
|
|
||
| from vllm import _custom_ops as ops | ||
| from vllm._aiter_ops import rocm_aiter_ops | ||
| from vllm.logger import init_logger | ||
| from vllm.platforms import current_platform | ||
|
|
||
| from .cutlass import CutlassScaledMMLinearKernel | ||
| from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig | ||
| from .ScaledMMLinearKernel import ( | ||
| FP8ScaledMMLinearKernel, | ||
| FP8ScaledMMLinearLayerConfig, | ||
| Int8ScaledMMLinearLayerConfig, | ||
| ) | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
| class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): | ||
|
|
@@ -115,3 +125,160 @@ def apply_weights( | |
| # b to be [N, K] | ||
| # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format | ||
| return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype) | ||
|
|
||
|
|
||
| class AiterBpreshufflePerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): | ||
| def get_ouput_padding(self) -> int | None: | ||
| # PTPC kernels do not require padding. | ||
| return None | ||
|
|
||
| @classmethod | ||
| def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: | ||
| if not current_platform.is_rocm(): | ||
| return (False, "AITER bpreshuffle is ROCm-only") | ||
|
|
||
| if not rocm_aiter_ops.is_linear_enabled(): | ||
| return (False, "AITER bpreshuffle is disabled by env var") | ||
|
|
||
| try: | ||
| import aiter # noqa: F401 | ||
| except Exception: | ||
| return (False, "AITER not installed") | ||
|
|
||
| # Check if the configuration is PTPC | ||
| is_per_channel_weight = c.weight_quant_key.scale.group_shape.is_per_token() | ||
| is_per_token_activation = ( | ||
| c.activation_quant_key.scale.group_shape.is_per_token() | ||
| ) | ||
| is_ptpc = is_per_channel_weight and is_per_token_activation | ||
|
|
||
| logger.info_once(f"AiterBpreshuffle: can_implement called. is_ptpc={is_ptpc}") | ||
|
|
||
| if not is_ptpc: | ||
| return (False, "This kernel only handles Per-Token/Per-Channel (PTPC)") | ||
|
|
||
| return True, None | ||
|
|
||
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||
| logger.info_once("AiterBpreshuffle: SHUFFLING WEIGHTS NOW.") | ||
|
|
||
| w_q, _, _, _ = self._get_layer_params(layer) | ||
|
|
||
| N = w_q.shape[1] | ||
| K = w_q.shape[0] | ||
|
|
||
| if N % 16 == 0 and K % 16 == 0: | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add https://github.com/ROCm/vllm/blob/c88d6d2ec7299605bb2ed8a4aee9260d90ef0631/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py#L153 to the rocm_aiter_ops and use that to replace this |
||
| # AITER shuffle_weight expectation [N, K] | ||
| w_q_nk = w_q.t().contiguous() | ||
|
|
||
| # Execute shuffle | ||
| shuffled_w_nk = shuffle_weight(w_q_nk, layout=(16, 16)) | ||
|
|
||
| del layer.weight | ||
| layer.register_buffer("weight", shuffled_w_nk) | ||
|
|
||
| logger.info_once("[AiterBpreshuffle: Weight shuffle COMPLETE.") | ||
|
|
||
| else: | ||
| raise ValueError( | ||
| f"Weight shape (N={N}, K={K}) not divisible by 16 " | ||
| "for AITER bpreshuffle." | ||
| ) | ||
|
|
||
| def apply_weights( | ||
| self, | ||
| layer: torch.nn.Module, | ||
| x: torch.Tensor, | ||
| bias: torch.Tensor | None = None, | ||
| ) -> torch.Tensor: | ||
| # 1. Obtain parameters | ||
| w_q, w_s, x_s, x_s_ub = self._get_layer_params(layer) | ||
| # 2. Dynamic quantization input | ||
| qinput, qinput_scale = self.quant_fp8(x, x_s, x_s_ub) | ||
|
|
||
| logger.info_once( | ||
| "AiterBpreshuffle: apply_weights... ABOUT TO CALL C++ KERNEL..." | ||
| ) | ||
|
|
||
| output = rocm_aiter_ops.gemm_a8w8_bpreshuffle( | ||
| qinput, | ||
| w_q, # Input [N, K] shuffle weights | ||
| out_dtype=self.config.out_dtype, | ||
| scale_a=qinput_scale, | ||
| scale_b=w_s, | ||
| ) | ||
|
|
||
| logger.info_once("AiterBpreshuffle: C++ KERNEL CALL SUCCEEDED.") | ||
|
|
||
| if bias is not None: | ||
| output.add_(bias) | ||
| return output | ||
|
|
||
| def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: | ||
| return rocm_aiter_ops.gemm_a8w8_bpreshuffle | ||
|
|
||
|
|
||
| class AiterCKPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): | ||
| """ | ||
| AITER PTPC kernel (gemm_a8w8_CK) without pre-shuffling. | ||
| """ | ||
|
|
||
| def get_ouput_padding(self) -> int | None: | ||
| return None | ||
|
|
||
| @classmethod | ||
| def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: | ||
| if not current_platform.is_rocm(): | ||
| return (False, "AITER CK is ROCm-only") | ||
|
|
||
| if not rocm_aiter_ops.is_linear_enabled(): | ||
| return (False, "AITER CK is disabled by env var") | ||
|
|
||
| try: | ||
| import aiter # noqa: F401 | ||
| except Exception: | ||
| return (False, "AITER not installed") | ||
|
|
||
| is_per_channel_weight = c.weight_quant_key.scale.group_shape.is_per_token() | ||
| is_per_token_activation = ( | ||
| c.activation_quant_key.scale.group_shape.is_per_token() | ||
| ) | ||
| is_ptpc = is_per_channel_weight and is_per_token_activation | ||
|
|
||
| logger.info_once(f"AiterCK: can_implement called. is_ptpc={is_ptpc}") | ||
|
|
||
| if not is_ptpc: | ||
| return (False, "This kernel only handles Per-Token/Per-Channel (PTPC)") | ||
|
|
||
| return True, None | ||
|
|
||
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||
| logger.info_once( | ||
| "AITER CK: process_weights_after_loading... DOING NOTHING (pass)." | ||
| ) | ||
| pass | ||
|
|
||
| def apply_weights( | ||
| self, | ||
| layer: torch.nn.Module, | ||
| x: torch.Tensor, | ||
| bias: torch.Tensor | None = None, | ||
| ) -> torch.Tensor: | ||
| w_q, w_s, x_s, x_s_ub = self._get_layer_params(layer) | ||
|
|
||
| qinput, qinput_scale = self.quant_fp8(x, x_s, x_s_ub) | ||
|
|
||
| logger.info_once( | ||
| "AiterCK: apply_weights... " | ||
| "ABOUT TO CALL C++ KERNEL (this is where it hangs)..." | ||
| ) | ||
|
|
||
| output = rocm_aiter_ops.gemm_a8w8( | ||
| qinput, w_q.t(), qinput_scale, w_s, bias, self.config.out_dtype | ||
| ) | ||
|
|
||
| logger.info_once("AiterCK: C++ KERNEL CALL SUCCEEDED.") | ||
| return output | ||
|
|
||
| def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]: | ||
| return rocm_aiter_ops.gemm_a8w8 | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check if fp8_linear is initialised.