From cc69d0b079853f9da7a9b6adff56c00d95419594 Mon Sep 17 00:00:00 2001 From: ganyi Date: Tue, 14 Oct 2025 01:59:23 +0000 Subject: [PATCH 1/5] enable persistent mla kernel Signed-off-by: ganyi --- vllm/attention/ops/rocm_aiter_mla.py | 40 ++++++++ .../attention/backends/mla/rocm_aiter_mla.py | 99 ++++++++++++++++++- 2 files changed, 134 insertions(+), 5 deletions(-) diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index 6308f63cc4e7..ce758b3f3013 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -33,6 +33,14 @@ def aiter_mla_decode_fwd( kv_indices: torch.Tensor | None = None, kv_last_page_lens: torch.Tensor | None = None, logit_cap: float = 0.0, + work_meta_data: torch.Tensor | None = None, + work_indptr: torch.Tensor | None = None, + work_info_set: torch.Tensor | None = None, + reduce_indptr: torch.Tensor | None = None, + reduce_final_map: torch.Tensor | None = None, + reduce_partial_map: torch.Tensor | None = None, + q_scale: torch.Tensor | None = None, + kv_scale: torch.Tensor | None = None, ): torch.ops.vllm.rocm_aiter_mla_decode_fwd( q, @@ -45,6 +53,14 @@ def aiter_mla_decode_fwd( kv_last_page_lens, sm_scale=sm_scale, logit_cap=logit_cap, + work_meta_data=work_meta_data, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=q_scale, + kv_scale=kv_scale, ) @@ -59,6 +75,14 @@ def mla_decode_fwd_impl( kv_last_page_lens: torch.Tensor | None = None, sm_scale: float = 1.0, logit_cap: float = 0.0, + work_meta_data: torch.Tensor | None = None, + work_indptr: torch.Tensor | None = None, + work_info_set: torch.Tensor | None = None, + reduce_indptr: torch.Tensor | None = None, + reduce_final_map: torch.Tensor | None = None, + reduce_partial_map: torch.Tensor | None = None, + q_scale: torch.Tensor | None = None, + kv_scale: torch.Tensor | None = None, ) -> None: from aiter.mla import mla_decode_fwd @@ -73,6 +97,14 @@ def mla_decode_fwd_impl( max_seqlen_qo, sm_scale=sm_scale, logit_cap=logit_cap, + work_meta_data=work_meta_data, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=q_scale, + kv_scale=kv_scale, ) @@ -87,6 +119,14 @@ def mla_decode_fwd_fake( kv_last_page_lens: torch.Tensor | None = None, sm_scale: float = 1.0, logit_cap: float = 0.0, + work_meta_data: torch.Tensor | None = None, + work_indptr: torch.Tensor | None = None, + work_info_set: torch.Tensor | None = None, + reduce_indptr: torch.Tensor | None = None, + reduce_final_map: torch.Tensor | None = None, + reduce_partial_map: torch.Tensor | None = None, + q_scale: torch.Tensor | None = None, + kv_scale: torch.Tensor | None = None, ) -> None: pass diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index d935c02243bd..49e0e077206a 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math from dataclasses import dataclass -from typing import ClassVar +from typing import ClassVar, Optional, Union import torch @@ -56,6 +57,20 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata): # The query indptr, shape : [num_decode + 1] qo_indptr: torch.Tensor | None = None + max_seqlen_qo: int = 1 + + work_metadata: torch.Tensor | None = None + + work_info_set: torch.Tensor | None = None + + work_indptr: torch.Tensor | None = None + + reduce_indptr: torch.Tensor | None = None + + reduce_final_map: torch.Tensor | None = None + + reduce_partial_map: torch.Tensor | None = None + class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): pass @@ -82,6 +97,10 @@ def __init__( "AITER MLAonly supports block size 1." ) + gpu = torch.cuda.current_device() + device_properties = torch.cuda.get_device_properties(gpu) + cu_num = device_properties.multi_processor_count + self.compilation_config = vllm_config.compilation_config max_num_pages_per_req = cdiv( vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size @@ -89,6 +108,36 @@ def __init__( max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req + # num_mtp = vllm_config.speculative_config.num_speculative_tokens + # num_mtp = 1 if num_mtp is None else num_mtp + max_seqlen_qo = ( + 1 + if vllm_config.speculative_config is None + else vllm_config.speculative_config.num_speculative_tokens + ) + + max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * self.num_heads / 128)) + self.work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda") + self.work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda") + self.work_info_set = torch.empty( + [max_num_reqs * max_qo_tiles_per_batch * cu_num, 8], + dtype=torch.int32, + device="cuda", + ).fill_(-1) + self.reduce_indptr = torch.empty( + [max_num_reqs * max_qo_tiles_per_batch + 1], + dtype=torch.int32, + device="cuda", + ) + self.reduce_final_map = torch.empty( + [max_num_reqs * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda" + ) + self.reduce_partial_map = torch.empty( + [max_num_reqs * max_qo_tiles_per_batch * cu_num], + dtype=torch.int32, + device="cuda", + ) + # Preparing persistent buffers # TODO: we can disambiguate between decode and mixed-prefill decode here # so we can only use the persistent buffer if a cudagraph is actually @@ -139,6 +188,32 @@ def _build_decode( block_table_bounds.cumsum(dim=0, dtype=torch.int32), ] ) + kv_indptr = torch.zeros( + [query_start_loc_cpu.size(0)], dtype=torch.int32, device="cuda" + ) + torch.cumsum(seq_lens_device, dim=0, out=kv_indptr[1:]) + query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + max_seqlen_qo = torch.max(query_lens).item() + + import aiter + + aiter.get_mla_metadata_v1( + query_start_loc_device, + kv_indptr, + self.num_heads // self.kv_cache_spec.num_kv_heads, + self.kv_cache_spec.num_kv_heads, + True, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + kv_granularity=max(self.kv_cache_spec.block_size, 16), + max_seqlen_qo=max_seqlen_qo, + uni_seqlen_qo=max_seqlen_qo, + fast_mode=True, + ) if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): num_actual_pages = paged_kv_indices.size(0) @@ -176,6 +251,13 @@ def _build_decode( paged_kv_last_page_len=paged_kv_last_page_len, qo_indptr=qo_indptr, dcp_tot_seq_lens=dcp_tot_seq_lens_device, + max_seqlen_qo=max_seqlen_qo, + work_metadata=self.work_metadata, + work_info_set=self.work_info_set, + work_indptr=self.work_indptr, + reduce_indptr=self.reduce_indptr, + reduce_final_map=self.reduce_final_map, + reduce_partial_map=self.reduce_partial_map, ) return attn_metadata @@ -256,24 +338,31 @@ def _forward_decode( assert isinstance(q, torch.Tensor) B = q.shape[0] o = torch.zeros( - B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device - ) + B, self.num_heads, self.kv_lora_rank, dtype=torch.bfloat16, device=q.device + ).fill_(-1) kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) # max_seqlen_qo must be 1 except for MTP # TODO: Find the best value for MTP - max_seqlen_qo = 1 aiter_mla_decode_fwd( q, kv_buffer, o, self.scale, attn_metadata.decode.qo_indptr, - max_seqlen_qo, + attn_metadata.decode.max_seqlen_qo, attn_metadata.decode.paged_kv_indptr, attn_metadata.decode.paged_kv_indices, attn_metadata.decode.paged_kv_last_page_len, + work_meta_data=attn_metadata.decode.work_metadata, + work_indptr=attn_metadata.decode.work_indptr, + work_info_set=attn_metadata.decode.work_info_set, + reduce_indptr=attn_metadata.decode.reduce_indptr, + reduce_final_map=attn_metadata.decode.reduce_final_map, + reduce_partial_map=attn_metadata.decode.reduce_partial_map, + q_scale=layer._q_scale, + kv_scale=layer._k_scale, ) return o, None From 4b817cd046094ac109916711fddef3181cfb9cce Mon Sep 17 00:00:00 2001 From: ganyi Date: Thu, 23 Oct 2025 06:28:16 +0000 Subject: [PATCH 2/5] fix lint Signed-off-by: ganyi --- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 49e0e077206a..c4e1f2ae3ca2 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -3,7 +3,7 @@ import math from dataclasses import dataclass -from typing import ClassVar, Optional, Union +from typing import ClassVar import torch From cb1b16c956163967d6ae656fe220948b0a441151 Mon Sep 17 00:00:00 2001 From: ganyi Date: Thu, 23 Oct 2025 09:38:06 +0000 Subject: [PATCH 3/5] add QueryLenSupport Signed-off-by: ganyi --- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index c4e1f2ae3ca2..c2623766a78c 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -18,6 +18,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder, + QueryLenSupport, ) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec @@ -82,6 +83,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE ) + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN def __init__( self, From f23e28bc58718006129a9655d449ab85a7d8cc12 Mon Sep 17 00:00:00 2001 From: ganyi Date: Mon, 27 Oct 2025 07:53:36 +0000 Subject: [PATCH 4/5] cudagraph support to uniform_batch Signed-off-by: ganyi --- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index c2623766a78c..fcbc839fdf21 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -80,9 +80,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): # TODO(luka, lucas): audit this as part of: # https://github.com/vllm-project/vllm/issues/22945 - cudagraph_support: ClassVar[AttentionCGSupport] = ( - AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE - ) + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN def __init__( From 4c22cf1544fae9160e1cf0cfdd43bff1223de294 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Wed, 29 Oct 2025 06:41:55 +0000 Subject: [PATCH 5/5] [DeepSeek R1]Use UNIFORM QueryLen MLA for MTP --- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index fcbc839fdf21..24ad59d21e53 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -81,7 +81,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): # TODO(luka, lucas): audit this as part of: # https://github.com/vllm-project/vllm/issues/22945 cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH - query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM def __init__( self,