diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index b09541dbf791..406e83bad305 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1645,6 +1645,33 @@ def get_and_maybe_dequant_weights(layer: LinearBase): # Convert from (L, N, P) to (N, P, L) self.W_UK_T = W_UK.permute(1, 2, 0) + def _concat_k_nope_k_pe( + self, k_nope: torch.Tensor, k_pe: torch.Tensor + ) -> torch.Tensor: + """ + Efficiently concatenate k_nope and k_pe tensors along the last dimension. + + This function avoids the performance penalty of torch.cat with expanded + non-contiguous tensors by pre-allocating the output and using direct copies. + + Args: + k_nope: Tensor of shape [..., nope_dim] + k_pe: Tensor to broadcast and concatenate, typically shape [..., 1, pe_dim] + or [..., pe_dim] + + Returns: + Tensor of shape [..., nope_dim + pe_dim] + """ + k = torch.empty( + (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]), + dtype=k_nope.dtype, + device=k_nope.device, + ) + # Direct copies with efficient broadcasting + k[..., : k_nope.shape[-1]] = k_nope + k[..., k_nope.shape[-1] :] = k_pe + return k + def _compute_prefill_context( self, q: torch.Tensor, @@ -1681,7 +1708,7 @@ def _compute_prefill_context( ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + k = self._concat_k_nope_k_pe(k_nope, k_pe) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -1785,7 +1812,7 @@ def _context_parallel_compute_prefill_context( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + k = self._concat_k_nope_k_pe(k_nope, k_pe) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -1834,7 +1861,7 @@ def _forward_prefill( ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + k = self._concat_k_nope_k_pe(k_nope, k_pe) output_prefill = self._run_prefill_new_tokens( prefill=attn_metadata.prefill,