From b43f3b17062e83a967aaf1a27aed3e0771899d86 Mon Sep 17 00:00:00 2001 From: wanfengcxz <2917021186@qq.com> Date: Sun, 6 Oct 2024 23:29:41 +0800 Subject: [PATCH 01/14] [camb] support internlm2 using dlinfer --- lmdeploy/messages.py | 2 +- lmdeploy/pytorch/backends/camb/__init__.py | 2 + .../pytorch/backends/camb/apply_rotary_emb.py | 40 +++++ lmdeploy/pytorch/backends/camb/attention.py | 138 +++++++++++++++++ lmdeploy/pytorch/backends/camb/norm.py | 36 +++++ lmdeploy/pytorch/backends/camb/op_backend.py | 141 ++++++++++++++++++ .../pytorch/backends/camb/rotary_embedding.py | 111 ++++++++++++++ lmdeploy/pytorch/backends/selector.py | 3 + lmdeploy/pytorch/check_env/__init__.py | 1 + lmdeploy/pytorch/kernels/camb/__init__.py | 14 ++ .../kernels/camb/apply_rotary_pos_emb.py | 25 ++++ .../pytorch/kernels/camb/fill_kv_cache.py | 13 ++ .../pytorch/kernels/camb/pagedattention.py | 121 +++++++++++++++ lmdeploy/pytorch/kernels/camb/rms_norm.py | 14 ++ lmdeploy/pytorch/models/internlm2.py | 1 + lmdeploy/pytorch/nn/rotary_embedding.py | 4 +- lmdeploy/utils.py | 6 +- run_internlm2.py | 20 +++ 18 files changed, 689 insertions(+), 3 deletions(-) create mode 100644 lmdeploy/pytorch/backends/camb/__init__.py create mode 100644 lmdeploy/pytorch/backends/camb/apply_rotary_emb.py create mode 100644 lmdeploy/pytorch/backends/camb/attention.py create mode 100644 lmdeploy/pytorch/backends/camb/norm.py create mode 100644 lmdeploy/pytorch/backends/camb/op_backend.py create mode 100644 lmdeploy/pytorch/backends/camb/rotary_embedding.py create mode 100644 lmdeploy/pytorch/kernels/camb/__init__.py create mode 100644 lmdeploy/pytorch/kernels/camb/apply_rotary_pos_emb.py create mode 100644 lmdeploy/pytorch/kernels/camb/fill_kv_cache.py create mode 100644 lmdeploy/pytorch/kernels/camb/pagedattention.py create mode 100644 lmdeploy/pytorch/kernels/camb/rms_norm.py create mode 100644 run_internlm2.py diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 90823598ea..2f4b7eb6a4 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -291,7 +291,7 @@ def __post_init__(self): assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks' assert self.quant_policy in (0, 4, 8), 'invalid quant_policy' assert self.device_type in [ - 'cuda', 'ascend', 'maca' + 'cuda', 'ascend', 'maca', 'camb' ], (f'invalid device_type: {self.device_type}') if self.quant_policy > 0 and self.device_type != 'cuda': assert False, 'kv cache quantization only works for CUDA.' diff --git a/lmdeploy/pytorch/backends/camb/__init__.py b/lmdeploy/pytorch/backends/camb/__init__.py new file mode 100644 index 0000000000..897495c209 --- /dev/null +++ b/lmdeploy/pytorch/backends/camb/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .op_backend import CambOpsBackend # noqa: F401 diff --git a/lmdeploy/pytorch/backends/camb/apply_rotary_emb.py b/lmdeploy/pytorch/backends/camb/apply_rotary_emb.py new file mode 100644 index 0000000000..f64c8487f8 --- /dev/null +++ b/lmdeploy/pytorch/backends/camb/apply_rotary_emb.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +from torch import Tensor + +from lmdeploy.pytorch.kernels.camb import apply_rotary_pos_emb + +from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl +from .attention import CambAttentionMetadata + +class CambApplyRotaryEmbImpl(ApplyRotaryEmbImpl): + """camb Apply rotary embedding implementation.""" + + def forward(self, + query: Tensor, + key: Tensor, + cos: Tensor, + sin: Tensor, + attn_metadata: CambAttentionMetadata, + inplace: bool = True): + """forward.""" + cos_sin_ids = attn_metadata.cos_sin_ids + cu_seqlens = attn_metadata.cu_seqlens + + if inplace: + q_embed = None + k_embed = None + else: + q_embed = torch.empty_like(query) + k_embed = torch.empty_like(key) + return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed, cos_sin_ids, cu_seqlens) + + +class CambApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder): + """camb Apply rotary embedding implementation builder.""" + + @staticmethod + def build(): + """build implementation.""" + return CambApplyRotaryEmbImpl() + diff --git a/lmdeploy/pytorch/backends/camb/attention.py b/lmdeploy/pytorch/backends/camb/attention.py new file mode 100644 index 0000000000..7dd2082c5d --- /dev/null +++ b/lmdeploy/pytorch/backends/camb/attention.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dataclasses import dataclass +from typing import Optional, Sequence + +from torch import Tensor + +from ..attention import AttentionBuilder, AttentionImpl, AttentionMetadata + + +@dataclass +class CambAttentionMetadata(AttentionMetadata): + kv_start_indices: Optional[Tensor] = None + block_size: int = 16 + attention_mask: Sequence[Tensor] = tuple() + is_unpaged_prefill: Optional[bool] = None + cu_seqlens: Optional[Tensor] = None + cos_sin_ids: Optional[Tensor] = None + max_q_seq_len: int = 1 + max_kv_seq_len: int = 1 + + +class CambAttentionImpl(AttentionImpl[CambAttentionMetadata]): + """camb attention implementation.""" + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float = None, + num_kv_heads: int = None, + v_head_size: int = None, + alibi: bool = None, + sliding_window: int = None, + logit_softcapping: float = None, + **kwargs, + ): + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + v_head_size, + alibi, + sliding_window, + logit_softcapping, + **kwargs, + ) + + from lmdeploy.pytorch.kernels.camb import (fill_kv_cache, + paged_attention_fwd) + self.fill_kv_cache = fill_kv_cache + self.paged_attention_fwd = paged_attention_fwd + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + k_cache: Tensor, + v_cache: Tensor, + attn_metadata: CambAttentionMetadata, + inplace: bool = True, + ) -> Tensor: + """forward.""" + + block_offsets = attn_metadata.block_offsets + q_start_loc = attn_metadata.q_start_loc + q_seqlens = attn_metadata.q_seqlens + kv_seqlens = attn_metadata.kv_seqlens + is_decoding = attn_metadata.is_decoding + kv_start_indices = attn_metadata.kv_start_indices + block_size = attn_metadata.block_size + attn_mask = attn_metadata.attention_mask + is_unpaged_prefill = attn_metadata.is_unpaged_prefill + max_q_seq_len = attn_metadata.max_q_seq_len + max_kv_seq_len = attn_metadata.max_kv_seq_len + cu_seqlens = attn_metadata.cu_seqlens + + # fill kv cache + k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache, + kv_start_indices) + + if inplace: + attn_output = query[..., :self.v_head_size] + else: + q_shape = query.shape + o_shape = q_shape[:-1] + (self.v_head_size, ) + attn_output = query.new_empty(o_shape) + + 
attn_output = self.paged_attention_fwd( + query, + key, + value, + attn_output, + k_cache, + v_cache, + block_offsets, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_seqlens=kv_seqlens, + max_q_seq_len=max_q_seq_len, + max_kv_seq_len=max_kv_seq_len, + is_decoding=is_decoding, + block_size=block_size, + cu_seqlens=cu_seqlens, + attn_mask=attn_mask, + is_unpaged_prefill=is_unpaged_prefill, + ) + + return attn_output + + +class CambAttentionBuilder(AttentionBuilder[CambAttentionMetadata]): + """camb attention builder.""" + + @staticmethod + def build( + num_heads: int, + head_size: int, + scale: float = None, + num_kv_heads: int = None, + v_head_size: int = None, + alibi_scale: float = None, + sliding_window: int = None, + logical_softcapping: float = None, + **kwargs, + ) -> CambAttentionImpl: + """build.""" + return CambAttentionImpl(num_heads, + head_size, + scale=scale, + num_kv_heads=num_kv_heads, + v_head_size=v_head_size, + alibi_scale=alibi_scale, + sliding_window=sliding_window, + logical_softcapping=logical_softcapping, + **kwargs) + diff --git a/lmdeploy/pytorch/backends/camb/norm.py b/lmdeploy/pytorch/backends/camb/norm.py new file mode 100644 index 0000000000..a400f84a0a --- /dev/null +++ b/lmdeploy/pytorch/backends/camb/norm.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from lmdeploy.pytorch.kernels.camb import rms_norm + +from ..norm import RMSNormBuilder, RMSNormImpl + + +class CambRMSNormImpl(RMSNormImpl): + """camb RMS norm implementation.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + self.hidden_size = hidden_size + self.eps = eps + + def forward(self, + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor = None): + """forward.""" + if residual is None: + x = rms_norm(x, weight, self.eps) + return x + else: + x, residual = rms_norm(x, weight, self.eps, residual=residual) + return x, residual + + +class CambRMSNormBuilder(RMSNormBuilder): + """camb RMS norm implementation builder.""" + + @staticmethod + def build(weight: torch.Tensor, eps: float = 1e-6): + """build.""" + return CambRMSNormImpl(weight, eps) + diff --git a/lmdeploy/pytorch/backends/camb/op_backend.py b/lmdeploy/pytorch/backends/camb/op_backend.py new file mode 100644 index 0000000000..a6e0e3fa3a --- /dev/null +++ b/lmdeploy/pytorch/backends/camb/op_backend.py @@ -0,0 +1,141 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Tuple + +import torch + +from lmdeploy.utils import get_logger + +from ..base import OpType +from ..default import DefaultOpsBackend + +logger = get_logger('lmdeploy') + + +class CambOpsBackend(DefaultOpsBackend): + """ascend layer backend.""" + + @staticmethod + def get_name() -> str: + """backend name.""" + return 'camb' + + @classmethod + def get_layer_impl_builder(cls, layer_type: OpType): + """get ascend layer builder.""" + if layer_type == OpType.Attention: + from .attention import CambAttentionBuilder + return CambAttentionBuilder + elif layer_type == OpType.ApplyRotaryEmb: + from .apply_rotary_emb import CambApplyRotaryEmbBuilder + return CambApplyRotaryEmbBuilder + elif layer_type == OpType.RMSNorm: + from .norm import CambRMSNormBuilder + return CambRMSNormBuilder + #elif layer_type == OpType.RotaryEmbedding: + # from .rotary_embedding import CambRotaryEmbeddingBuilder + # return CambRotaryEmbeddingBuilder + else: + logger.debug( + f'Op {layer_type} fallback to default implementation.') + return super().get_layer_impl_builder(layer_type) + + @staticmethod + def get_attention_metadata_cls(): + from .attention import CambAttentionMetadata + return CambAttentionMetadata + + @staticmethod + def get_k_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + return ( + #block_size, + num_heads, + block_size, + head_size, + ) + + @staticmethod + def get_v_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + return ( + #block_size, + num_heads, + block_size, + head_size, + ) + + @classmethod + def update_step_context(cls, step_context): + """update step context.""" + kv_start_indices, attention_mask = [], [] + #_, block_size, _, _ = step_context.kv_caches[0][0].shape + _, _, block_size, _ = step_context.kv_caches[0][0].shape + device = step_context.block_offsets.device + batch_size = step_context.q_start_loc.shape[0] + + is_unpaged_prefill = False + q_start_loc = step_context.q_start_loc + q_seqlens = step_context.q_seqlens + kv_seqlens = step_context.kv_seqlens.to(torch.int32) + max_q_seq_len = torch.max(q_seqlens).cpu().item() + max_kv_seq_len = torch.max(kv_seqlens).cpu().item() + + cu_seqlens = torch.zeros(batch_size+1, dtype=torch.int32, device=device) + cu_seqlens[:-1] = step_context.q_start_loc + cu_seqlens[-1] = step_context.q_seqlens.sum() + cu_seqlens_list = cu_seqlens.tolist() + + if not step_context.is_decoding: + cos_sin_ids = step_context.position_ids[0] + else: + cos_sin_ids = torch.zeros(batch_size, dtype=torch.int32, device=device) + + if not step_context.is_decoding: + is_unpaged_prefill = \ + all((step_context.q_seqlens == + step_context.kv_seqlens).tolist()) + + for i in range(batch_size): + q_seq_len = int(step_context.q_seqlens[i]) + kv_seq_len = int(step_context.kv_seqlens[i]) + history_length = kv_seq_len - q_seq_len + block_idx = history_length // block_size + block_loc = step_context.block_offsets[i][block_idx] + token_loc = history_length % block_size + for j in range(q_seq_len): + kv_start_indices.append(block_loc * block_size + token_loc) + if j == q_seq_len - 1: + break + token_loc = (token_loc + 1) % block_size + block_idx = block_idx if token_loc else block_idx + 1 + block_loc = step_context.block_offsets[i][block_idx] + kv_start_indices = torch.tensor(kv_start_indices, device=device) + + attn_meta_cls = cls.get_attention_metadata_cls() + attn_metadata = attn_meta_cls( + step_context.is_decoding, + 
step_context.block_offsets, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_seqlens=kv_seqlens, + kv_start_indices=kv_start_indices, + block_size=block_size, + attention_mask=None, + is_unpaged_prefill=is_unpaged_prefill, + max_q_seq_len=max_q_seq_len, + max_kv_seq_len=max_kv_seq_len, + cu_seqlens=cu_seqlens, + cos_sin_ids=cos_sin_ids, + ) + + step_context.attn_metadata = attn_metadata + return step_context + diff --git a/lmdeploy/pytorch/backends/camb/rotary_embedding.py b/lmdeploy/pytorch/backends/camb/rotary_embedding.py new file mode 100644 index 0000000000..a299816063 --- /dev/null +++ b/lmdeploy/pytorch/backends/camb/rotary_embedding.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +from torch import nn + +from ..rotary_embedding import RotaryEmbeddingImpl + +def _rotary_embedding_fwd(position_ids: torch.Tensor, + inv_freq: torch.Tensor, + scaling_factor: float, + mscale: float = None, + dtype: torch.dtype = None, + device_type: torch.device = None): + """rotary embedding forward.""" + if dtype is None: + dtype = torch.float16 + if device_type is None: + device_type = 'cuda' + position_ids = position_ids.float() / scaling_factor + inv_freq_expanded = inv_freq[None, :, + None].float().expand(position_ids.shape[0], + -1, 1) + position_ids_expanded = position_ids[:, None, :] + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = device_type if isinstance( + device_type, str) and device_type != 'mps' else 'cpu' + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() + @ position_ids_expanded.float()).transpose(1, 2) + emb = freqs.repeat(1, 1, 2) + cos = emb.cos() + sin = emb.sin() + + if mscale is not None: + cos = cos * mscale + sin = sin * mscale + + return cos.to(dtype=dtype), sin.to(dtype=dtype) + + +class RotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module): + """base rotary embedding.""" + + def __init__(self, + dim: int, + base: int = 10000, + scaling_factor: float = 1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.base = base + inv_freq = 1.0 / (self.base**(torch.arange( + 0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + def forward(self, x: torch.Tensor, position_ids: torch.Tensor): + """forward.""" + device_type = x.device.type + dtype = x.dtype + if self.inv_freq.device != x.device: + self.inv_freq = self.inv_freq.to(x.device) + return _rotary_embedding_fwd(position_ids, + self.inv_freq, + scaling_factor=self.scaling_factor, + dtype=dtype, + device_type=device_type) + +class CambRotaryEmbeddingBuilder(RotaryEmbeddingBuilder): + """rotary embedding builder.""" + + @staticmethod + def build( + dim: int, + max_position_embeddings: int = 2048, + base: int = 10000, + scaling_factor: float = 1.0, + yarn_params: YarnParameters = None, + longrope_params: LongRoPEScalingParameters = None, + llama3_params: Llama3Parameters = None, + emb_type: RopeType = RopeType.Default, + ): + """build.""" + if emb_type in (RopeType.Default, RopeType.LinearScaling): + return RotaryEmbeddingImpl(dim, base, scaling_factor) + elif emb_type == RopeType.DynamicNTKScaling: + return LlamaDynamicNTKScalingRotaryEmbedding( + dim, base, scaling_factor, max_position_embeddings) + elif emb_type == RopeType.Llama3: + return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor, + llama3_params.low_freq_factor, + 
llama3_params.high_freq_factor, + max_position_embeddings) + elif emb_type == RopeType.Yarn: + return YarnRotaryEmbeddingImpl(dim, + base, + scaling_factor, + max_position_embeddings, + yarn_params=yarn_params) + elif emb_type == RopeType.LongRoPEScaling: + return LongRoPEScalingRotaryEmbeddingImpl( + dim, + base, + max_position_embeddings=max_position_embeddings, + longrope_params=longrope_params, + ) + else: + raise NotImplementedError( + f'Unsupported embedding type: {emb_type}') + diff --git a/lmdeploy/pytorch/backends/selector.py b/lmdeploy/pytorch/backends/selector.py index 987730a981..01956c739d 100644 --- a/lmdeploy/pytorch/backends/selector.py +++ b/lmdeploy/pytorch/backends/selector.py @@ -18,5 +18,8 @@ def get_backend(): if device_type == 'maca': from .dlinfer import MacaOpsBackend return MacaOpsBackend + if device_type == 'camb': + from .camb import CambOpsBackend + return CambOpsBackend else: raise RuntimeError(f'Unsupported device type: {device_type}') diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py index 7d72438224..2d3ad73439 100644 --- a/lmdeploy/pytorch/check_env/__init__.py +++ b/lmdeploy/pytorch/check_env/__init__.py @@ -33,6 +33,7 @@ def try_import_deeplink(device_type: str): 'ascend', 'npu', 'maca', + 'camb', ] if device_type in deeplink_device_type_list: logger = get_logger('lmdeploy') diff --git a/lmdeploy/pytorch/kernels/camb/__init__.py b/lmdeploy/pytorch/kernels/camb/__init__.py new file mode 100644 index 0000000000..96ebf88b2a --- /dev/null +++ b/lmdeploy/pytorch/kernels/camb/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..default import multinomial_sampling +from .apply_rotary_pos_emb import apply_rotary_pos_emb +from .fill_kv_cache import fill_kv_cache +from .pagedattention import paged_attention_fwd +from .rms_norm import rms_norm + +__all__ = [ + 'rms_norm', + 'apply_rotary_pos_emb', + 'fill_kv_cache', + 'paged_attention_fwd', +] + diff --git a/lmdeploy/pytorch/kernels/camb/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/camb/apply_rotary_pos_emb.py new file mode 100644 index 0000000000..478613f08a --- /dev/null +++ b/lmdeploy/pytorch/kernels/camb/apply_rotary_pos_emb.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dlinfer.ops as ext_ops +from torch import Tensor + + +def apply_rotary_pos_emb( + query_states: Tensor, + key_states: Tensor, + cos: Tensor, + sin: Tensor, + q_embed=None, + k_embed=None, + cos_sin_ids=None, + cu_seqlens=None, +): + query_states, key_states = ext_ops.apply_rotary_pos_emb(query_states, key_states, cos, sin, None, cos_sin_ids, cu_seqlens) + if q_embed is None or q_embed.data_ptr() == query_states.data_ptr(): + q_embed = query_states + else: + q_embed.copy_(query_states) + if k_embed is None or k_embed.data_ptr() == key_states.data_ptr(): + k_embed = key_states + else: + k_embed.copy_(key_states) + return q_embed, k_embed diff --git a/lmdeploy/pytorch/kernels/camb/fill_kv_cache.py b/lmdeploy/pytorch/kernels/camb/fill_kv_cache.py new file mode 100644 index 0000000000..483dcb3779 --- /dev/null +++ b/lmdeploy/pytorch/kernels/camb/fill_kv_cache.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import dlinfer.ops as ext_ops +from torch import Tensor + +def fill_kv_cache( + key_states: Tensor, + value_states: Tensor, + key_caches: Tensor, + value_caches: Tensor, + kv_start_indices: Tensor, +): + """fill key/value state to cache for paged attention.""" + return ext_ops.fill_kv_cache(key_states, value_states, key_caches, value_caches, kv_start_indices) diff --git a/lmdeploy/pytorch/kernels/camb/pagedattention.py b/lmdeploy/pytorch/kernels/camb/pagedattention.py new file mode 100644 index 0000000000..e7e33de0da --- /dev/null +++ b/lmdeploy/pytorch/kernels/camb/pagedattention.py @@ -0,0 +1,121 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dlinfer.ops as ext_ops +import math +import torch +from dlinfer.utils.type_annotation import Optional, Sequence, Tensor + +def prefill_attention( + query_states: Tensor, + key_states: Tensor, + value_states: Tensor, + attn_output: Tensor, + key_cache: Tensor, + value_cache: Tensor, + block_offsets: Tensor, + q_start_loc: Tensor, + q_seq_len: Tensor, + kv_seq_len: Tensor, + max_q_seq_len: int, + block_size: int, + cu_seqlens: Tensor, + attn_mask: Sequence[Optional[Tensor]], + is_unpaged_prefill: Optional[bool], +): + num_q_heads = query_states.shape[1] + num_kv_heads = key_states.shape[1] + + if is_unpaged_prefill: + output = torch.empty_like(query_states) + ext_ops.prefill_attention( + query_states, + key_states, + value_states, + cu_seqlens, + q_seq_len, + max_q_seq_len, + num_q_heads, + num_kv_heads, + attn_mask, + softmax_scale=1.0, + attn_output=output) + attn_output.copy_(output) + return attn_output + else: + pass + +def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, + max_kv_seq_len, block_offsets, block_size): + num_q_heads = q.shape[1] + num_kv_heads = k_cache.shape[1] + q = q.unsqueeze(1) + #attn_output = attn_output.unsqueeze(1) + + max_kv_seq_len = torch.max(kv_seq_len) + + ret = ext_ops.paged_decode_attention( + q, + k_cache, + v_cache, + block_offsets, + block_size, + kv_seq_len, + max_kv_seq_len, + num_q_heads, + num_kv_heads, + softmax_scale = 1. / math.sqrt(q.shape[-1]), + attn_output = q, + ) + + return q + + +def paged_attention_fwd( + query_states: Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attn_output: Tensor, + key_cache: Tensor, + value_cache: Tensor, + block_offsets: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_seqlens: Tensor, + max_q_seq_len: int, + max_kv_seq_len: int, + is_decoding: bool, + block_size: int, + cu_seqlens: Tensor, + attn_mask: Sequence[Optional[Tensor]] = (), + is_unpaged_prefill: Optional[bool] = None, +): + if not is_decoding: + return prefill_attention( + query_states, + key_states, + value_states, + attn_output, + key_cache, + value_cache, + block_offsets, + q_start_loc, + q_seqlens, + kv_seqlens, + max_q_seq_len, + block_size, + cu_seqlens, + attn_mask, + is_unpaged_prefill, + ) + + else: + return paged_token_attention( + query_states, + key_cache, + value_cache, + attn_output, + kv_seqlens, + max_kv_seq_len, + block_offsets, + block_size, + ) + diff --git a/lmdeploy/pytorch/kernels/camb/rms_norm.py b/lmdeploy/pytorch/kernels/camb/rms_norm.py new file mode 100644 index 0000000000..47aa556361 --- /dev/null +++ b/lmdeploy/pytorch/kernels/camb/rms_norm.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import dlinfer.ops as ext_ops +from torch import Tensor + +def rms_norm(hidden_states: Tensor, weight: Tensor, epsilon: float = 1e-6, residual: Tensor = None, out: Tensor = None): + if residual is None: + rms_norm_out = ext_ops.rms_norm(hidden_states, weight, epsilon) + if out is None: + out = rms_norm_out + else: + out.copy_(rms_norm_out) + return out + else: + return ext_ops.add_rms_norm(hidden_states, residual, weight, epsilon) diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index 497090afc5..06754c9871 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -82,6 +82,7 @@ def forward( key_states, cos, sin, + attn_metadata, inplace=True, ) diff --git a/lmdeploy/pytorch/nn/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py index 43eb1f806d..9b1dd3eea7 100644 --- a/lmdeploy/pytorch/nn/rotary_embedding.py +++ b/lmdeploy/pytorch/nn/rotary_embedding.py @@ -3,6 +3,7 @@ from transformers import PretrainedConfig from ..backends import OpType, get_backend +from ..backends.attention import AttentionMetadata from ..backends.rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters, RopeType, YarnParameters) @@ -122,6 +123,7 @@ def forward(self, key: Tensor, cos: Tensor, sin: Tensor, + attn_metadata: AttentionMetadata, inplace: bool = True): """forward.""" - return self.impl.forward(query, key, cos, sin, inplace) + return self.impl.forward(query, key, cos, sin, attn_metadata, inplace) diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py index fbdd374f80..c065e5ff76 100644 --- a/lmdeploy/utils.py +++ b/lmdeploy/utils.py @@ -332,7 +332,7 @@ def get_max_batch_size(device_type: str): Args: device_type (str): the type of device """ - assert device_type in ['cuda', 'ascend', 'maca'] + assert device_type in ['cuda', 'ascend', 'maca', 'camb'] if device_type == 'cuda': max_batch_size_map = { 'a100': 256, @@ -352,6 +352,8 @@ def get_max_batch_size(device_type: str): return 16 elif device_type == 'maca': return 128 + elif device_type == 'camb': + return 16 def is_bf16_supported(device_type: str = 'cuda'): @@ -387,5 +389,7 @@ def is_bf16_supported(device_type: str = 'cuda'): # return False elif device_type == 'maca': return True + elif device_type == 'camb': + return True else: return False diff --git a/run_internlm2.py b/run_internlm2.py new file mode 100644 index 0000000000..5621ccbce3 --- /dev/null +++ b/run_internlm2.py @@ -0,0 +1,20 @@ +# import dlinfer +import lmdeploy +import torch +from lmdeploy import PytorchEngineConfig +if __name__ == "__main__": + # torch.set_printoptions(precision=10) + b = PytorchEngineConfig(tp=1,block_size=16, cache_max_entry_count=0.4, device_type="camb") + pipe = lmdeploy.pipeline("/root/.cache/modelscope/hub/Shanghai_AI_Laboratory/internlm2-chat-7b", + backend_config = b) + #question = ["Hi, pls intro yourself", "Please introduce Shanghai."] + question = ["Hi, pls intro yourself", "Hi, pls intro yourself"] + #question = ["Hi, pls intro yourself in detail"] + print(question) + response = pipe(question, do_preprocess=False, top_k=1) + print(response) + # for idx, r in enumerate(response): + # print(f"Q: {question[idx]}") + # print(f"AAAAAA: {r.text}") + # print() + # print("end") From dc2d9e858dbd9ba241a7af3b6b6283aae1244303 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 12 Oct 2024 17:25:25 +0800 Subject: [PATCH 02/14] change --- lmdeploy/pytorch/backends/camb/attention.py | 12 +++++------- lmdeploy/pytorch/backends/camb/op_backend.py | 10 +++++----- 
lmdeploy/pytorch/kernels/camb/fill_kv_cache.py | 3 ++- .../pytorch/kernels/camb/pagedattention.py | 5 ++--- lmdeploy/pytorch/models/internlm2.py | 3 +++ run_internlm2.py | 18 ++++++++++++------ use_modelscope.sh | 1 + 7 files changed, 30 insertions(+), 22 deletions(-) create mode 100644 use_modelscope.sh diff --git a/lmdeploy/pytorch/backends/camb/attention.py b/lmdeploy/pytorch/backends/camb/attention.py index 7dd2082c5d..6942346c35 100644 --- a/lmdeploy/pytorch/backends/camb/attention.py +++ b/lmdeploy/pytorch/backends/camb/attention.py @@ -17,8 +17,7 @@ class CambAttentionMetadata(AttentionMetadata): cos_sin_ids: Optional[Tensor] = None max_q_seq_len: int = 1 max_kv_seq_len: int = 1 - - + class CambAttentionImpl(AttentionImpl[CambAttentionMetadata]): """camb attention implementation.""" @@ -45,7 +44,6 @@ def __init__( logit_softcapping, **kwargs, ) - from lmdeploy.pytorch.kernels.camb import (fill_kv_cache, paged_attention_fwd) self.fill_kv_cache = fill_kv_cache @@ -62,7 +60,6 @@ def forward( inplace: bool = True, ) -> Tensor: """forward.""" - block_offsets = attn_metadata.block_offsets q_start_loc = attn_metadata.q_start_loc q_seqlens = attn_metadata.q_seqlens @@ -77,8 +74,9 @@ def forward( cu_seqlens = attn_metadata.cu_seqlens # fill kv cache - k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache, - kv_start_indices) + # k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache, + # kv_start_indices) + self.fill_kv_cache(key, value, k_cache, v_cache, kv_start_indices) if inplace: attn_output = query[..., :self.v_head_size] @@ -86,7 +84,7 @@ def forward( q_shape = query.shape o_shape = q_shape[:-1] + (self.v_head_size, ) attn_output = query.new_empty(o_shape) - + attn_output = self.paged_attention_fwd( query, key, diff --git a/lmdeploy/pytorch/backends/camb/op_backend.py b/lmdeploy/pytorch/backends/camb/op_backend.py index a6e0e3fa3a..9992b11e21 100644 --- a/lmdeploy/pytorch/backends/camb/op_backend.py +++ b/lmdeploy/pytorch/backends/camb/op_backend.py @@ -12,7 +12,7 @@ class CambOpsBackend(DefaultOpsBackend): - """ascend layer backend.""" + """Camb layer backend.""" @staticmethod def get_name() -> str: @@ -21,7 +21,7 @@ def get_name() -> str: @classmethod def get_layer_impl_builder(cls, layer_type: OpType): - """get ascend layer builder.""" + """get Camb layer builder.""" if layer_type == OpType.Attention: from .attention import CambAttentionBuilder return CambAttentionBuilder @@ -94,7 +94,7 @@ def update_step_context(cls, step_context): cu_seqlens_list = cu_seqlens.tolist() if not step_context.is_decoding: - cos_sin_ids = step_context.position_ids[0] + cos_sin_ids = step_context.position_ids[0].to(torch.int32) else: cos_sin_ids = torch.zeros(batch_size, dtype=torch.int32, device=device) @@ -117,12 +117,12 @@ def update_step_context(cls, step_context): token_loc = (token_loc + 1) % block_size block_idx = block_idx if token_loc else block_idx + 1 block_loc = step_context.block_offsets[i][block_idx] - kv_start_indices = torch.tensor(kv_start_indices, device=device) + kv_start_indices = torch.tensor(kv_start_indices, device=device, dtype=torch.int32) attn_meta_cls = cls.get_attention_metadata_cls() attn_metadata = attn_meta_cls( step_context.is_decoding, - step_context.block_offsets, + step_context.block_offsets.to(torch.int32), q_start_loc=q_start_loc, q_seqlens=q_seqlens, kv_seqlens=kv_seqlens, diff --git a/lmdeploy/pytorch/kernels/camb/fill_kv_cache.py b/lmdeploy/pytorch/kernels/camb/fill_kv_cache.py index 483dcb3779..9448c1d4a4 100644 --- 
a/lmdeploy/pytorch/kernels/camb/fill_kv_cache.py +++ b/lmdeploy/pytorch/kernels/camb/fill_kv_cache.py @@ -10,4 +10,5 @@ def fill_kv_cache( kv_start_indices: Tensor, ): """fill key/value state to cache for paged attention.""" - return ext_ops.fill_kv_cache(key_states, value_states, key_caches, value_caches, kv_start_indices) + ext_ops.fill_kv_cache(key_states, value_states, key_caches, value_caches, kv_start_indices) + diff --git a/lmdeploy/pytorch/kernels/camb/pagedattention.py b/lmdeploy/pytorch/kernels/camb/pagedattention.py index e7e33de0da..d91f504705 100644 --- a/lmdeploy/pytorch/kernels/camb/pagedattention.py +++ b/lmdeploy/pytorch/kernels/camb/pagedattention.py @@ -36,7 +36,8 @@ def prefill_attention( num_q_heads, num_kv_heads, attn_mask, - softmax_scale=1.0, + # softmax_scale=1.0, + softmax_scale = 1. / math.sqrt(query_states.shape[-1]), attn_output=output) attn_output.copy_(output) return attn_output @@ -51,7 +52,6 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, #attn_output = attn_output.unsqueeze(1) max_kv_seq_len = torch.max(kv_seq_len) - ret = ext_ops.paged_decode_attention( q, k_cache, @@ -65,7 +65,6 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, softmax_scale = 1. / math.sqrt(q.shape[-1]), attn_output = q, ) - return q diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index 06754c9871..9193fed56f 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -195,12 +195,15 @@ def forward( hidden_states, residual) # Self Attention + # print("hidden_states before:",torch.mean(hidden_states)) hidden_states = self.attention( hidden_states=hidden_states, rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, attn_metadata=attn_metadata, ) + # print("hidden_states after:",torch.mean(hidden_states)) + # Fully Connected hidden_states, residual = self.ffn_norm(hidden_states, residual) diff --git a/run_internlm2.py b/run_internlm2.py index 5621ccbce3..198d51175c 100644 --- a/run_internlm2.py +++ b/run_internlm2.py @@ -3,13 +3,19 @@ import torch from lmdeploy import PytorchEngineConfig if __name__ == "__main__": + seed = 1024 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # torch.set_printoptions(precision=10) - b = PytorchEngineConfig(tp=1,block_size=16, cache_max_entry_count=0.4, device_type="camb") - pipe = lmdeploy.pipeline("/root/.cache/modelscope/hub/Shanghai_AI_Laboratory/internlm2-chat-7b", - backend_config = b) - #question = ["Hi, pls intro yourself", "Please introduce Shanghai."] - question = ["Hi, pls intro yourself", "Hi, pls intro yourself"] - #question = ["Hi, pls intro yourself in detail"] + b = PytorchEngineConfig(tp=1,block_size=16, cache_max_entry_count=0.4, device_type="camb", download_dir="/workspace/volume/shangda/share/llm_models") + pipe = lmdeploy.pipeline("Shanghai_AI_Laboratory/internlm2_5-7b", + backend_config = b) + # pipe = lmdeploy.pipeline("Shanghai_AI_Laboratory/internlm2-chat-7b", + # backend_config = b) + # question = ["Hi, pls intro yourself", "Please introduce Shanghai."] + # question = ["Hi, pls intro yourself", "Hi, pls intro yourself"] + question = ["Hi, pls intro yourself", "who is your father"] print(question) response = pipe(question, do_preprocess=False, top_k=1) print(response) diff --git a/use_modelscope.sh b/use_modelscope.sh new file mode 100644 index 0000000000..7c04f57c80 --- /dev/null +++ b/use_modelscope.sh @@ -0,0 +1 @@ +export LMDEPLOY_USE_MODELSCOPE=True 
From 0010dcf5807e4a42057847d08f84c8d785b146d6 Mon Sep 17 00:00:00 2001 From: wanfengcxz Date: Wed, 16 Oct 2024 11:37:23 +0800 Subject: [PATCH 03/14] change --- lmdeploy/pytorch/kernels/camb/pagedattention.py | 4 +--- lmdeploy/pytorch/models/internlm2.py | 15 ++++++--------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/lmdeploy/pytorch/kernels/camb/pagedattention.py b/lmdeploy/pytorch/kernels/camb/pagedattention.py index d91f504705..7fc05280bc 100644 --- a/lmdeploy/pytorch/kernels/camb/pagedattention.py +++ b/lmdeploy/pytorch/kernels/camb/pagedattention.py @@ -49,9 +49,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, num_q_heads = q.shape[1] num_kv_heads = k_cache.shape[1] q = q.unsqueeze(1) - #attn_output = attn_output.unsqueeze(1) - - max_kv_seq_len = torch.max(kv_seq_len) + ret = ext_ops.paged_decode_attention( q, k_cache, diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index 9193fed56f..e50afb7292 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -74,7 +74,7 @@ def forward( qkv_states = qkv_states.flatten(0, -2) query_states, key_states, value_states = self.wqkv.split_qkv( qkv_states) - + # apply rotary embedding cos, sin = rotary_pos_emb query_states, key_states = self.apply_rotary_pos_emb( @@ -86,7 +86,7 @@ def forward( inplace=True, ) - # attention + #attention attn_output = self.attn_fwd( query_states, key_states, @@ -138,14 +138,14 @@ def __init__(self, dtype=dtype, device=device, is_tp=True) - + + @torch.profiler.record_function("MLP") def forward(self, x): """forward.""" gate_up = self.gate_up_proj(x) act = self.act_fn(gate_up) return self.w2(act) - class InternLM2DecoderLayer(nn.Module): """decoder layer.""" @@ -186,7 +186,7 @@ def forward( residual: Optional[torch.Tensor] = None, attn_metadata: Any = None, ): - + if residual is None: residual = hidden_states hidden_states = self.attention_norm(hidden_states) @@ -195,15 +195,12 @@ def forward( hidden_states, residual) # Self Attention - # print("hidden_states before:",torch.mean(hidden_states)) hidden_states = self.attention( hidden_states=hidden_states, rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, attn_metadata=attn_metadata, ) - # print("hidden_states after:",torch.mean(hidden_states)) - # Fully Connected hidden_states, residual = self.ffn_norm(hidden_states, residual) @@ -306,7 +303,7 @@ def forward( # norm hidden_states, _ = self.norm(hidden_states, residual) - + return hidden_states def get_input_embeddings(self): From 119781f0466b51804b216d2f73b916b3e63b449c Mon Sep 17 00:00:00 2001 From: wanfengcxz <2917021186@qq.com> Date: Thu, 17 Oct 2024 17:28:18 +0800 Subject: [PATCH 04/14] [camb] add silu_and_mul --- lmdeploy/pytorch/backends/camb/activation.py | 22 ++++++++++++++++++++ lmdeploy/pytorch/backends/camb/op_backend.py | 6 +++--- lmdeploy/pytorch/kernels/camb/activation.py | 9 ++++++++ 3 files changed, 34 insertions(+), 3 deletions(-) create mode 100644 lmdeploy/pytorch/backends/camb/activation.py create mode 100644 lmdeploy/pytorch/kernels/camb/activation.py diff --git a/lmdeploy/pytorch/backends/camb/activation.py b/lmdeploy/pytorch/backends/camb/activation.py new file mode 100644 index 0000000000..d195d73e7f --- /dev/null +++ b/lmdeploy/pytorch/backends/camb/activation.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from torch import nn + +from lmdeploy.pytorch.kernels.camb.activation import silu_and_mul + +from ..activation import (GeluAndMulBuilder, GeluAndMulImpl, SiluAndMulBuilder, + SiluAndMulImpl) + +class CambSiluAndMulImpl(SiluAndMulImpl): + """silu + multiple fused implementation.""" + + def forward(self, x): + """forward.""" + return silu_and_mul(x) + +class CambSiluAndMulBuilder(SiluAndMulBuilder): + """silu and mul implementation builder.""" + + @staticmethod + def build(inplace: bool = False): + """build.""" + return CambSiluAndMulImpl() diff --git a/lmdeploy/pytorch/backends/camb/op_backend.py b/lmdeploy/pytorch/backends/camb/op_backend.py index 9992b11e21..d83d8ef2fb 100644 --- a/lmdeploy/pytorch/backends/camb/op_backend.py +++ b/lmdeploy/pytorch/backends/camb/op_backend.py @@ -31,9 +31,9 @@ def get_layer_impl_builder(cls, layer_type: OpType): elif layer_type == OpType.RMSNorm: from .norm import CambRMSNormBuilder return CambRMSNormBuilder - #elif layer_type == OpType.RotaryEmbedding: - # from .rotary_embedding import CambRotaryEmbeddingBuilder - # return CambRotaryEmbeddingBuilder + elif layer_type == OpType.SiluAndMul: + from .activation import CambSiluAndMulBuilder + return CambSiluAndMulBuilder else: logger.debug( f'Op {layer_type} fallback to default implementation.') diff --git a/lmdeploy/pytorch/kernels/camb/activation.py b/lmdeploy/pytorch/kernels/camb/activation.py new file mode 100644 index 0000000000..2d2298d1ec --- /dev/null +++ b/lmdeploy/pytorch/kernels/camb/activation.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dlinfer.ops as ext_ops +from torch import Tensor + + +def silu_and_mul( + input_tensor: Tensor, +) -> Tensor: + return ext_ops.silu_and_mul(input_tensor) From fdfacb92e21e0fe08aab12db796c776be40eb4d3 Mon Sep 17 00:00:00 2001 From: wanfengcxz <2917021186@qq.com> Date: Thu, 24 Oct 2024 12:58:07 +0800 Subject: [PATCH 05/14] refactor for multiple devices in dlinfer(camb) --- lmdeploy/pytorch/backends/camb/activation.py | 22 --- .../pytorch/backends/camb/apply_rotary_emb.py | 40 ------ lmdeploy/pytorch/backends/camb/attention.py | 136 ------------------ lmdeploy/pytorch/backends/camb/norm.py | 36 ----- .../pytorch/backends/camb/rotary_embedding.py | 111 -------------- lmdeploy/pytorch/backends/dlinfer/__init__.py | 1 + .../backends/dlinfer/apply_rotary_emb.py | 8 +- .../pytorch/backends/dlinfer/attention.py | 6 +- .../backends/{ => dlinfer}/camb/__init__.py | 0 .../backends/{ => dlinfer}/camb/op_backend.py | 35 +---- lmdeploy/pytorch/backends/selector.py | 2 +- lmdeploy/pytorch/kernels/camb/__init__.py | 14 -- lmdeploy/pytorch/kernels/camb/activation.py | 9 -- .../kernels/camb/apply_rotary_pos_emb.py | 25 ---- .../pytorch/kernels/camb/fill_kv_cache.py | 14 -- .../pytorch/kernels/camb/pagedattention.py | 118 --------------- lmdeploy/pytorch/kernels/camb/rms_norm.py | 14 -- lmdeploy/pytorch/kernels/dlinfer/__init__.py | 2 + .../kernels/dlinfer/apply_rotary_pos_emb.py | 20 +-- .../pytorch/kernels/dlinfer/pagedattention.py | 19 ++- lmdeploy/pytorch/models/internlm2.py | 13 +- run_internlm2.py | 26 ---- use_modelscope.sh | 1 - 23 files changed, 45 insertions(+), 627 deletions(-) delete mode 100644 lmdeploy/pytorch/backends/camb/activation.py delete mode 100644 lmdeploy/pytorch/backends/camb/apply_rotary_emb.py delete mode 100644 lmdeploy/pytorch/backends/camb/attention.py delete mode 100644 lmdeploy/pytorch/backends/camb/norm.py delete mode 100644 lmdeploy/pytorch/backends/camb/rotary_embedding.py rename lmdeploy/pytorch/backends/{ => 
dlinfer}/camb/__init__.py (100%) rename lmdeploy/pytorch/backends/{ => dlinfer}/camb/op_backend.py (73%) delete mode 100644 lmdeploy/pytorch/kernels/camb/__init__.py delete mode 100644 lmdeploy/pytorch/kernels/camb/activation.py delete mode 100644 lmdeploy/pytorch/kernels/camb/apply_rotary_pos_emb.py delete mode 100644 lmdeploy/pytorch/kernels/camb/fill_kv_cache.py delete mode 100644 lmdeploy/pytorch/kernels/camb/pagedattention.py delete mode 100644 lmdeploy/pytorch/kernels/camb/rms_norm.py delete mode 100644 run_internlm2.py delete mode 100644 use_modelscope.sh diff --git a/lmdeploy/pytorch/backends/camb/activation.py b/lmdeploy/pytorch/backends/camb/activation.py deleted file mode 100644 index d195d73e7f..0000000000 --- a/lmdeploy/pytorch/backends/camb/activation.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from torch import nn - -from lmdeploy.pytorch.kernels.camb.activation import silu_and_mul - -from ..activation import (GeluAndMulBuilder, GeluAndMulImpl, SiluAndMulBuilder, - SiluAndMulImpl) - -class CambSiluAndMulImpl(SiluAndMulImpl): - """silu + multiple fused implementation.""" - - def forward(self, x): - """forward.""" - return silu_and_mul(x) - -class CambSiluAndMulBuilder(SiluAndMulBuilder): - """silu and mul implementation builder.""" - - @staticmethod - def build(inplace: bool = False): - """build.""" - return CambSiluAndMulImpl() diff --git a/lmdeploy/pytorch/backends/camb/apply_rotary_emb.py b/lmdeploy/pytorch/backends/camb/apply_rotary_emb.py deleted file mode 100644 index f64c8487f8..0000000000 --- a/lmdeploy/pytorch/backends/camb/apply_rotary_emb.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -from torch import Tensor - -from lmdeploy.pytorch.kernels.camb import apply_rotary_pos_emb - -from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl -from .attention import CambAttentionMetadata - -class CambApplyRotaryEmbImpl(ApplyRotaryEmbImpl): - """camb Apply rotary embedding implementation.""" - - def forward(self, - query: Tensor, - key: Tensor, - cos: Tensor, - sin: Tensor, - attn_metadata: CambAttentionMetadata, - inplace: bool = True): - """forward.""" - cos_sin_ids = attn_metadata.cos_sin_ids - cu_seqlens = attn_metadata.cu_seqlens - - if inplace: - q_embed = None - k_embed = None - else: - q_embed = torch.empty_like(query) - k_embed = torch.empty_like(key) - return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed, cos_sin_ids, cu_seqlens) - - -class CambApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder): - """camb Apply rotary embedding implementation builder.""" - - @staticmethod - def build(): - """build implementation.""" - return CambApplyRotaryEmbImpl() - diff --git a/lmdeploy/pytorch/backends/camb/attention.py b/lmdeploy/pytorch/backends/camb/attention.py deleted file mode 100644 index 6942346c35..0000000000 --- a/lmdeploy/pytorch/backends/camb/attention.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from dataclasses import dataclass -from typing import Optional, Sequence - -from torch import Tensor - -from ..attention import AttentionBuilder, AttentionImpl, AttentionMetadata - - -@dataclass -class CambAttentionMetadata(AttentionMetadata): - kv_start_indices: Optional[Tensor] = None - block_size: int = 16 - attention_mask: Sequence[Tensor] = tuple() - is_unpaged_prefill: Optional[bool] = None - cu_seqlens: Optional[Tensor] = None - cos_sin_ids: Optional[Tensor] = None - max_q_seq_len: int = 1 - max_kv_seq_len: int = 1 - -class CambAttentionImpl(AttentionImpl[CambAttentionMetadata]): - """camb attention implementation.""" - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float = None, - num_kv_heads: int = None, - v_head_size: int = None, - alibi: bool = None, - sliding_window: int = None, - logit_softcapping: float = None, - **kwargs, - ): - super().__init__( - num_heads, - head_size, - scale, - num_kv_heads, - v_head_size, - alibi, - sliding_window, - logit_softcapping, - **kwargs, - ) - from lmdeploy.pytorch.kernels.camb import (fill_kv_cache, - paged_attention_fwd) - self.fill_kv_cache = fill_kv_cache - self.paged_attention_fwd = paged_attention_fwd - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - k_cache: Tensor, - v_cache: Tensor, - attn_metadata: CambAttentionMetadata, - inplace: bool = True, - ) -> Tensor: - """forward.""" - block_offsets = attn_metadata.block_offsets - q_start_loc = attn_metadata.q_start_loc - q_seqlens = attn_metadata.q_seqlens - kv_seqlens = attn_metadata.kv_seqlens - is_decoding = attn_metadata.is_decoding - kv_start_indices = attn_metadata.kv_start_indices - block_size = attn_metadata.block_size - attn_mask = attn_metadata.attention_mask - is_unpaged_prefill = attn_metadata.is_unpaged_prefill - max_q_seq_len = attn_metadata.max_q_seq_len - max_kv_seq_len = attn_metadata.max_kv_seq_len - cu_seqlens = attn_metadata.cu_seqlens - - # fill kv cache - # k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache, - # kv_start_indices) - self.fill_kv_cache(key, value, k_cache, v_cache, kv_start_indices) - - if inplace: - attn_output = query[..., :self.v_head_size] - else: - q_shape = query.shape - o_shape = q_shape[:-1] + (self.v_head_size, ) - attn_output = query.new_empty(o_shape) - - attn_output = self.paged_attention_fwd( - query, - key, - value, - attn_output, - k_cache, - v_cache, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seqlens, - kv_seqlens=kv_seqlens, - max_q_seq_len=max_q_seq_len, - max_kv_seq_len=max_kv_seq_len, - is_decoding=is_decoding, - block_size=block_size, - cu_seqlens=cu_seqlens, - attn_mask=attn_mask, - is_unpaged_prefill=is_unpaged_prefill, - ) - - return attn_output - - -class CambAttentionBuilder(AttentionBuilder[CambAttentionMetadata]): - """camb attention builder.""" - - @staticmethod - def build( - num_heads: int, - head_size: int, - scale: float = None, - num_kv_heads: int = None, - v_head_size: int = None, - alibi_scale: float = None, - sliding_window: int = None, - logical_softcapping: float = None, - **kwargs, - ) -> CambAttentionImpl: - """build.""" - return CambAttentionImpl(num_heads, - head_size, - scale=scale, - num_kv_heads=num_kv_heads, - v_head_size=v_head_size, - alibi_scale=alibi_scale, - sliding_window=sliding_window, - logical_softcapping=logical_softcapping, - **kwargs) - diff --git a/lmdeploy/pytorch/backends/camb/norm.py b/lmdeploy/pytorch/backends/camb/norm.py deleted file mode 100644 index a400f84a0a..0000000000 --- 
a/lmdeploy/pytorch/backends/camb/norm.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch - -from lmdeploy.pytorch.kernels.camb import rms_norm - -from ..norm import RMSNormBuilder, RMSNormImpl - - -class CambRMSNormImpl(RMSNormImpl): - """camb RMS norm implementation.""" - - def __init__(self, hidden_size: int, eps: float = 1e-6): - self.hidden_size = hidden_size - self.eps = eps - - def forward(self, - x: torch.Tensor, - weight: torch.Tensor, - residual: torch.Tensor = None): - """forward.""" - if residual is None: - x = rms_norm(x, weight, self.eps) - return x - else: - x, residual = rms_norm(x, weight, self.eps, residual=residual) - return x, residual - - -class CambRMSNormBuilder(RMSNormBuilder): - """camb RMS norm implementation builder.""" - - @staticmethod - def build(weight: torch.Tensor, eps: float = 1e-6): - """build.""" - return CambRMSNormImpl(weight, eps) - diff --git a/lmdeploy/pytorch/backends/camb/rotary_embedding.py b/lmdeploy/pytorch/backends/camb/rotary_embedding.py deleted file mode 100644 index a299816063..0000000000 --- a/lmdeploy/pytorch/backends/camb/rotary_embedding.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import math - -import torch -from torch import nn - -from ..rotary_embedding import RotaryEmbeddingImpl - -def _rotary_embedding_fwd(position_ids: torch.Tensor, - inv_freq: torch.Tensor, - scaling_factor: float, - mscale: float = None, - dtype: torch.dtype = None, - device_type: torch.device = None): - """rotary embedding forward.""" - if dtype is None: - dtype = torch.float16 - if device_type is None: - device_type = 'cuda' - position_ids = position_ids.float() / scaling_factor - inv_freq_expanded = inv_freq[None, :, - None].float().expand(position_ids.shape[0], - -1, 1) - position_ids_expanded = position_ids[:, None, :] - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = device_type if isinstance( - device_type, str) and device_type != 'mps' else 'cpu' - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() - @ position_ids_expanded.float()).transpose(1, 2) - emb = freqs.repeat(1, 1, 2) - cos = emb.cos() - sin = emb.sin() - - if mscale is not None: - cos = cos * mscale - sin = sin * mscale - - return cos.to(dtype=dtype), sin.to(dtype=dtype) - - -class RotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module): - """base rotary embedding.""" - - def __init__(self, - dim: int, - base: int = 10000, - scaling_factor: float = 1.0): - super().__init__() - self.scaling_factor = scaling_factor - self.dim = dim - self.base = base - inv_freq = 1.0 / (self.base**(torch.arange( - 0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer('inv_freq', inv_freq, persistent=False) - - def forward(self, x: torch.Tensor, position_ids: torch.Tensor): - """forward.""" - device_type = x.device.type - dtype = x.dtype - if self.inv_freq.device != x.device: - self.inv_freq = self.inv_freq.to(x.device) - return _rotary_embedding_fwd(position_ids, - self.inv_freq, - scaling_factor=self.scaling_factor, - dtype=dtype, - device_type=device_type) - -class CambRotaryEmbeddingBuilder(RotaryEmbeddingBuilder): - """rotary embedding builder.""" - - @staticmethod - def build( - dim: int, - max_position_embeddings: int = 2048, - base: int = 10000, - scaling_factor: float = 1.0, - yarn_params: YarnParameters = None, - longrope_params: LongRoPEScalingParameters = 
None, - llama3_params: Llama3Parameters = None, - emb_type: RopeType = RopeType.Default, - ): - """build.""" - if emb_type in (RopeType.Default, RopeType.LinearScaling): - return RotaryEmbeddingImpl(dim, base, scaling_factor) - elif emb_type == RopeType.DynamicNTKScaling: - return LlamaDynamicNTKScalingRotaryEmbedding( - dim, base, scaling_factor, max_position_embeddings) - elif emb_type == RopeType.Llama3: - return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor, - llama3_params.low_freq_factor, - llama3_params.high_freq_factor, - max_position_embeddings) - elif emb_type == RopeType.Yarn: - return YarnRotaryEmbeddingImpl(dim, - base, - scaling_factor, - max_position_embeddings, - yarn_params=yarn_params) - elif emb_type == RopeType.LongRoPEScaling: - return LongRoPEScalingRotaryEmbeddingImpl( - dim, - base, - max_position_embeddings=max_position_embeddings, - longrope_params=longrope_params, - ) - else: - raise NotImplementedError( - f'Unsupported embedding type: {emb_type}') - diff --git a/lmdeploy/pytorch/backends/dlinfer/__init__.py b/lmdeploy/pytorch/backends/dlinfer/__init__.py index af3ccff085..b1d2382936 100644 --- a/lmdeploy/pytorch/backends/dlinfer/__init__.py +++ b/lmdeploy/pytorch/backends/dlinfer/__init__.py @@ -1,3 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. from .ascend import AscendOpsBackend # noqa: F401 from .maca import MacaOpsBackend # noqa: F401 +from .camb import CambOpsBackend # noqa: F401 diff --git a/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py b/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py index c2bc1f7dce..d47dea00d6 100644 --- a/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py +++ b/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py @@ -5,7 +5,7 @@ from lmdeploy.pytorch.kernels.dlinfer import apply_rotary_pos_emb from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl - +from .attention import DlinferAttentionMetadata class DlinferApplyRotaryEmbImpl(ApplyRotaryEmbImpl): """Apply rotary embedding implementation.""" @@ -15,15 +15,19 @@ def forward(self, key: Tensor, cos: Tensor, sin: Tensor, + attn_metadata: DlinferAttentionMetadata, inplace: bool = True): """forward.""" + cos_sin_ids = attn_metadata.cos_sin_ids + cu_seqlens = attn_metadata.cu_seqlens + if inplace: q_embed = None k_embed = None else: q_embed = torch.empty_like(query) k_embed = torch.empty_like(key) - return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed) + return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed, cos_sin_ids, cu_seqlens) class DlinferApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder): diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index 0d666c9130..070edf6187 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -10,11 +10,13 @@ @dataclass class DlinferAttentionMetadata(AttentionMetadata): kv_start_indices: Optional[Tensor] = None - block_size: int = 64 + block_size: int = 16 attention_mask: Sequence[Tensor] = tuple() is_unpaged_prefill: Optional[bool] = None max_q_seq_len: int = 1 max_kv_seq_len: int = 1 + cu_seqlens: Optional[Tensor] = None + cos_sin_ids: Optional[Tensor] = None class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]): @@ -74,6 +76,7 @@ def forward( is_unpaged_prefill = attn_metadata.is_unpaged_prefill max_q_seq_len = attn_metadata.max_q_seq_len max_kv_seq_len = attn_metadata.max_kv_seq_len + cu_seqlens = attn_metadata.cu_seqlens # fill kv cache k_cache, 
v_cache = self.fill_kv_cache(key, value, k_cache, v_cache, @@ -101,6 +104,7 @@ def forward( max_kv_seq_len=max_kv_seq_len, is_decoding=is_decoding, block_size=block_size, + cu_seqlens=cu_seqlens, attn_mask=attn_mask, is_unpaged_prefill=is_unpaged_prefill, ) diff --git a/lmdeploy/pytorch/backends/camb/__init__.py b/lmdeploy/pytorch/backends/dlinfer/camb/__init__.py similarity index 100% rename from lmdeploy/pytorch/backends/camb/__init__.py rename to lmdeploy/pytorch/backends/dlinfer/camb/__init__.py diff --git a/lmdeploy/pytorch/backends/camb/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py similarity index 73% rename from lmdeploy/pytorch/backends/camb/op_backend.py rename to lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py index d83d8ef2fb..901717349f 100644 --- a/lmdeploy/pytorch/backends/camb/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py @@ -5,45 +5,19 @@ from lmdeploy.utils import get_logger -from ..base import OpType -from ..default import DefaultOpsBackend +from ..op_backend import DlinferOpsBackend logger = get_logger('lmdeploy') -class CambOpsBackend(DefaultOpsBackend): - """Camb layer backend.""" +class CambOpsBackend(DlinferOpsBackend): + """camb layer backend.""" @staticmethod def get_name() -> str: """backend name.""" return 'camb' - @classmethod - def get_layer_impl_builder(cls, layer_type: OpType): - """get Camb layer builder.""" - if layer_type == OpType.Attention: - from .attention import CambAttentionBuilder - return CambAttentionBuilder - elif layer_type == OpType.ApplyRotaryEmb: - from .apply_rotary_emb import CambApplyRotaryEmbBuilder - return CambApplyRotaryEmbBuilder - elif layer_type == OpType.RMSNorm: - from .norm import CambRMSNormBuilder - return CambRMSNormBuilder - elif layer_type == OpType.SiluAndMul: - from .activation import CambSiluAndMulBuilder - return CambSiluAndMulBuilder - else: - logger.debug( - f'Op {layer_type} fallback to default implementation.') - return super().get_layer_impl_builder(layer_type) - - @staticmethod - def get_attention_metadata_cls(): - from .attention import CambAttentionMetadata - return CambAttentionMetadata - @staticmethod def get_k_block_shape( block_size: int, @@ -52,7 +26,6 @@ def get_k_block_shape( dtype: torch.dtype, ) -> Tuple[int, ...]: return ( - #block_size, num_heads, block_size, head_size, @@ -66,7 +39,6 @@ def get_v_block_shape( dtype: torch.dtype, ) -> Tuple[int, ...]: return ( - #block_size, num_heads, block_size, head_size, @@ -76,7 +48,6 @@ def get_v_block_shape( def update_step_context(cls, step_context): """update step context.""" kv_start_indices, attention_mask = [], [] - #_, block_size, _, _ = step_context.kv_caches[0][0].shape _, _, block_size, _ = step_context.kv_caches[0][0].shape device = step_context.block_offsets.device batch_size = step_context.q_start_loc.shape[0] diff --git a/lmdeploy/pytorch/backends/selector.py b/lmdeploy/pytorch/backends/selector.py index 01956c739d..4db73fa370 100644 --- a/lmdeploy/pytorch/backends/selector.py +++ b/lmdeploy/pytorch/backends/selector.py @@ -19,7 +19,7 @@ def get_backend(): from .dlinfer import MacaOpsBackend return MacaOpsBackend if device_type == 'camb': - from .camb import CambOpsBackend + from .dlinfer import CambOpsBackend return CambOpsBackend else: raise RuntimeError(f'Unsupported device type: {device_type}') diff --git a/lmdeploy/pytorch/kernels/camb/__init__.py b/lmdeploy/pytorch/kernels/camb/__init__.py deleted file mode 100644 index 96ebf88b2a..0000000000 --- 
a/lmdeploy/pytorch/kernels/camb/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from ..default import multinomial_sampling -from .apply_rotary_pos_emb import apply_rotary_pos_emb -from .fill_kv_cache import fill_kv_cache -from .pagedattention import paged_attention_fwd -from .rms_norm import rms_norm - -__all__ = [ - 'rms_norm', - 'apply_rotary_pos_emb', - 'fill_kv_cache', - 'paged_attention_fwd', -] - diff --git a/lmdeploy/pytorch/kernels/camb/activation.py b/lmdeploy/pytorch/kernels/camb/activation.py deleted file mode 100644 index 2d2298d1ec..0000000000 --- a/lmdeploy/pytorch/kernels/camb/activation.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import dlinfer.ops as ext_ops -from torch import Tensor - - -def silu_and_mul( - input_tensor: Tensor, -) -> Tensor: - return ext_ops.silu_and_mul(input_tensor) diff --git a/lmdeploy/pytorch/kernels/camb/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/camb/apply_rotary_pos_emb.py deleted file mode 100644 index 478613f08a..0000000000 --- a/lmdeploy/pytorch/kernels/camb/apply_rotary_pos_emb.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import dlinfer.ops as ext_ops -from torch import Tensor - - -def apply_rotary_pos_emb( - query_states: Tensor, - key_states: Tensor, - cos: Tensor, - sin: Tensor, - q_embed=None, - k_embed=None, - cos_sin_ids=None, - cu_seqlens=None, -): - query_states, key_states = ext_ops.apply_rotary_pos_emb(query_states, key_states, cos, sin, None, cos_sin_ids, cu_seqlens) - if q_embed is None or q_embed.data_ptr() == query_states.data_ptr(): - q_embed = query_states - else: - q_embed.copy_(query_states) - if k_embed is None or k_embed.data_ptr() == key_states.data_ptr(): - k_embed = key_states - else: - k_embed.copy_(key_states) - return q_embed, k_embed diff --git a/lmdeploy/pytorch/kernels/camb/fill_kv_cache.py b/lmdeploy/pytorch/kernels/camb/fill_kv_cache.py deleted file mode 100644 index 9448c1d4a4..0000000000 --- a/lmdeploy/pytorch/kernels/camb/fill_kv_cache.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import dlinfer.ops as ext_ops -from torch import Tensor - -def fill_kv_cache( - key_states: Tensor, - value_states: Tensor, - key_caches: Tensor, - value_caches: Tensor, - kv_start_indices: Tensor, -): - """fill key/value state to cache for paged attention.""" - ext_ops.fill_kv_cache(key_states, value_states, key_caches, value_caches, kv_start_indices) - diff --git a/lmdeploy/pytorch/kernels/camb/pagedattention.py b/lmdeploy/pytorch/kernels/camb/pagedattention.py deleted file mode 100644 index 7fc05280bc..0000000000 --- a/lmdeploy/pytorch/kernels/camb/pagedattention.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import dlinfer.ops as ext_ops -import math -import torch -from dlinfer.utils.type_annotation import Optional, Sequence, Tensor - -def prefill_attention( - query_states: Tensor, - key_states: Tensor, - value_states: Tensor, - attn_output: Tensor, - key_cache: Tensor, - value_cache: Tensor, - block_offsets: Tensor, - q_start_loc: Tensor, - q_seq_len: Tensor, - kv_seq_len: Tensor, - max_q_seq_len: int, - block_size: int, - cu_seqlens: Tensor, - attn_mask: Sequence[Optional[Tensor]], - is_unpaged_prefill: Optional[bool], -): - num_q_heads = query_states.shape[1] - num_kv_heads = key_states.shape[1] - - if is_unpaged_prefill: - output = torch.empty_like(query_states) - ext_ops.prefill_attention( - query_states, - key_states, - value_states, - cu_seqlens, - q_seq_len, - max_q_seq_len, - num_q_heads, - num_kv_heads, - attn_mask, - # softmax_scale=1.0, - softmax_scale = 1. / math.sqrt(query_states.shape[-1]), - attn_output=output) - attn_output.copy_(output) - return attn_output - else: - pass - -def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, - max_kv_seq_len, block_offsets, block_size): - num_q_heads = q.shape[1] - num_kv_heads = k_cache.shape[1] - q = q.unsqueeze(1) - - ret = ext_ops.paged_decode_attention( - q, - k_cache, - v_cache, - block_offsets, - block_size, - kv_seq_len, - max_kv_seq_len, - num_q_heads, - num_kv_heads, - softmax_scale = 1. / math.sqrt(q.shape[-1]), - attn_output = q, - ) - return q - - -def paged_attention_fwd( - query_states: Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - attn_output: Tensor, - key_cache: Tensor, - value_cache: Tensor, - block_offsets: Tensor, - q_start_loc: Tensor, - q_seqlens: Tensor, - kv_seqlens: Tensor, - max_q_seq_len: int, - max_kv_seq_len: int, - is_decoding: bool, - block_size: int, - cu_seqlens: Tensor, - attn_mask: Sequence[Optional[Tensor]] = (), - is_unpaged_prefill: Optional[bool] = None, -): - if not is_decoding: - return prefill_attention( - query_states, - key_states, - value_states, - attn_output, - key_cache, - value_cache, - block_offsets, - q_start_loc, - q_seqlens, - kv_seqlens, - max_q_seq_len, - block_size, - cu_seqlens, - attn_mask, - is_unpaged_prefill, - ) - - else: - return paged_token_attention( - query_states, - key_cache, - value_cache, - attn_output, - kv_seqlens, - max_kv_seq_len, - block_offsets, - block_size, - ) - diff --git a/lmdeploy/pytorch/kernels/camb/rms_norm.py b/lmdeploy/pytorch/kernels/camb/rms_norm.py deleted file mode 100644 index 47aa556361..0000000000 --- a/lmdeploy/pytorch/kernels/camb/rms_norm.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import dlinfer.ops as ext_ops -from torch import Tensor - -def rms_norm(hidden_states: Tensor, weight: Tensor, epsilon: float = 1e-6, residual: Tensor = None, out: Tensor = None): - if residual is None: - rms_norm_out = ext_ops.rms_norm(hidden_states, weight, epsilon) - if out is None: - out = rms_norm_out - else: - out.copy_(rms_norm_out) - return out - else: - return ext_ops.add_rms_norm(hidden_states, residual, weight, epsilon) diff --git a/lmdeploy/pytorch/kernels/dlinfer/__init__.py b/lmdeploy/pytorch/kernels/dlinfer/__init__.py index 8f86f0019a..0014a88e4d 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/__init__.py +++ b/lmdeploy/pytorch/kernels/dlinfer/__init__.py @@ -8,6 +8,7 @@ from .moe_gating_topk_softmax import moe_gating_topk_softmax from .pagedattention import paged_attention_fwd from .rms_norm import rms_norm +from .activation import silu_and_mul __all__ = [ 'rms_norm', @@ -19,4 +20,5 @@ 'linear', 'moe_gating_topk_softmax', 'multinomial_sampling', + 'silu_and_mul', ] diff --git a/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py index 0f13f3f38c..94e42a6b1d 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py +++ b/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py @@ -15,24 +15,16 @@ def apply_rotary_pos_emb( ) -> Tuple[Tensor, Tensor]: query_states = query_states.contiguous() key_states = key_states.contiguous() - bs = query_states.shape[0] - query_states_reshaped = query_states.unsqueeze(0) - key_states_reshaped = key_states.unsqueeze(0) - cos_reshaped = cos.reshape(1, bs, 1, -1) - sin_reshaped = sin.reshape(1, bs, 1, -1) - query_states_reshaped, key_states_reshaped = \ - ext_ops.apply_rotary_pos_emb(query_states_reshaped, - key_states_reshaped, - cos_reshaped, sin_reshaped, - None, None) + query_states, key_states = ext_ops.apply_rotary_pos_emb(query_states, key_states, cos, sin, None, cos_sin_ids, cu_seqlens) + if q_embed is None: - q_embed = query_states_reshaped.view(query_states.shape) + q_embed = query_states elif q_embed is not query_states: - q_embed.copy_(query_states_reshaped.view(query_states.shape)) + q_embed.copy_(query_states) if k_embed is None: - k_embed = key_states_reshaped.view(key_states.shape) + k_embed = key_states elif k_embed is not key_states: - k_embed.copy_(key_states_reshaped.view(key_states.shape)) + k_embed.copy_(key_states) return q_embed, k_embed diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py index 21c72074a4..267a96740e 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py +++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py @@ -17,6 +17,7 @@ def prefill_attention( kv_seq_len: Tensor, max_q_seq_len: int, block_size: int, + cu_seqlens: Tensor, attn_mask: Sequence[Optional[Tensor]], is_unpaged_prefill: Optional[bool], ) -> Tensor: @@ -24,18 +25,21 @@ def prefill_attention( num_kv_heads = value_states.shape[1] if is_unpaged_prefill: - return ext_ops.prefill_attention( + output = torch.empty_like(query_states) + ext_ops.prefill_attention( query_states, key_states, value_states, - q_start_loc, + cu_seqlens, q_seq_len, max_q_seq_len, num_q_heads, num_kv_heads, attn_mask, - attn_output=attn_output, + attn_output=output, ) + attn_output.copy_(output) + return attn_output else: return ext_ops.paged_prefill_attention( query_states, @@ -55,8 +59,11 @@ def prefill_attention( def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, max_kv_seq_len, block_offsets, 
block_size): - num_q_heads, q_head_dim = q.shape[1:3] - num_kv_heads = k_cache.shape[-1] // q_head_dim + num_q_heads = q.shape[1] + num_kv_heads = k_cache.shape[1] + q = q.unsqueeze(1) + attn_output = attn_output.unsqueeze(1) + return ext_ops.paged_decode_attention( q, k_cache, @@ -86,6 +93,7 @@ def paged_attention_fwd( max_kv_seq_len: int, is_decoding: bool, block_size: int, + cu_seqlens: Tensor, attn_mask: Sequence[Optional[Tensor]] = (), is_unpaged_prefill: Optional[bool] = None, ): @@ -103,6 +111,7 @@ def paged_attention_fwd( kv_seqlens, max_q_seq_len, block_size, + cu_seqlens, attn_mask, is_unpaged_prefill, ) diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index e50afb7292..cf0706659a 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -74,7 +74,7 @@ def forward( qkv_states = qkv_states.flatten(0, -2) query_states, key_states, value_states = self.wqkv.split_qkv( qkv_states) - + # apply rotary embedding cos, sin = rotary_pos_emb query_states, key_states = self.apply_rotary_pos_emb( @@ -86,7 +86,7 @@ def forward( inplace=True, ) - #attention + # attention attn_output = self.attn_fwd( query_states, key_states, @@ -138,14 +138,14 @@ def __init__(self, dtype=dtype, device=device, is_tp=True) - - @torch.profiler.record_function("MLP") + def forward(self, x): """forward.""" gate_up = self.gate_up_proj(x) act = self.act_fn(gate_up) return self.w2(act) + class InternLM2DecoderLayer(nn.Module): """decoder layer.""" @@ -186,7 +186,7 @@ def forward( residual: Optional[torch.Tensor] = None, attn_metadata: Any = None, ): - + if residual is None: residual = hidden_states hidden_states = self.attention_norm(hidden_states) @@ -303,7 +303,7 @@ def forward( # norm hidden_states, _ = self.norm(hidden_states, residual) - + return hidden_states def get_input_embeddings(self): @@ -440,3 +440,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) + diff --git a/run_internlm2.py b/run_internlm2.py deleted file mode 100644 index 198d51175c..0000000000 --- a/run_internlm2.py +++ /dev/null @@ -1,26 +0,0 @@ -# import dlinfer -import lmdeploy -import torch -from lmdeploy import PytorchEngineConfig -if __name__ == "__main__": - seed = 1024 - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - # torch.set_printoptions(precision=10) - b = PytorchEngineConfig(tp=1,block_size=16, cache_max_entry_count=0.4, device_type="camb", download_dir="/workspace/volume/shangda/share/llm_models") - pipe = lmdeploy.pipeline("Shanghai_AI_Laboratory/internlm2_5-7b", - backend_config = b) - # pipe = lmdeploy.pipeline("Shanghai_AI_Laboratory/internlm2-chat-7b", - # backend_config = b) - # question = ["Hi, pls intro yourself", "Please introduce Shanghai."] - # question = ["Hi, pls intro yourself", "Hi, pls intro yourself"] - question = ["Hi, pls intro yourself", "who is your father"] - print(question) - response = pipe(question, do_preprocess=False, top_k=1) - print(response) - # for idx, r in enumerate(response): - # print(f"Q: {question[idx]}") - # print(f"AAAAAA: {r.text}") - # print() - # print("end") diff --git a/use_modelscope.sh b/use_modelscope.sh deleted file mode 100644 index 7c04f57c80..0000000000 --- a/use_modelscope.sh +++ /dev/null @@ -1 +0,0 @@ -export LMDEPLOY_USE_MODELSCOPE=True From a4a31d390ed581307e45f45ac1c300f9965dd812 Mon Sep 17 00:00:00 2001 From: wanfengcxz <2917021186@qq.com> Date: Thu, 24 Oct 2024 
15:36:44 +0800 Subject: [PATCH 06/14] refactor and remove pos_id in apply_rotary_emb --- lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py | 8 ++------ lmdeploy/pytorch/backends/dlinfer/attention.py | 9 ++++++--- lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py | 8 +------- lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py | 2 +- lmdeploy/pytorch/kernels/dlinfer/pagedattention.py | 7 ++----- lmdeploy/pytorch/models/internlm2.py | 3 ++- lmdeploy/pytorch/nn/rotary_embedding.py | 5 ++--- 7 files changed, 16 insertions(+), 26 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py b/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py index d47dea00d6..f474ca0804 100644 --- a/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py +++ b/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py @@ -5,7 +5,6 @@ from lmdeploy.pytorch.kernels.dlinfer import apply_rotary_pos_emb from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl -from .attention import DlinferAttentionMetadata class DlinferApplyRotaryEmbImpl(ApplyRotaryEmbImpl): """Apply rotary embedding implementation.""" @@ -15,19 +14,16 @@ def forward(self, key: Tensor, cos: Tensor, sin: Tensor, - attn_metadata: DlinferAttentionMetadata, + cu_seqlens: Tensor, inplace: bool = True): """forward.""" - cos_sin_ids = attn_metadata.cos_sin_ids - cu_seqlens = attn_metadata.cu_seqlens - if inplace: q_embed = None k_embed = None else: q_embed = torch.empty_like(query) k_embed = torch.empty_like(key) - return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed, cos_sin_ids, cu_seqlens) + return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed, cu_seqlens) class DlinferApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder): diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index 070edf6187..52fd3d6c80 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -10,14 +10,13 @@ @dataclass class DlinferAttentionMetadata(AttentionMetadata): kv_start_indices: Optional[Tensor] = None - block_size: int = 16 + block_size: int = 64 attention_mask: Sequence[Tensor] = tuple() is_unpaged_prefill: Optional[bool] = None max_q_seq_len: int = 1 max_kv_seq_len: int = 1 cu_seqlens: Optional[Tensor] = None - cos_sin_ids: Optional[Tensor] = None - + is_flash_attn_support_inplace: bool = True class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]): """dlinfer attention implementation.""" @@ -82,6 +81,10 @@ def forward( k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache, kv_start_indices) + if is_unpaged_prefill: + inplace = inplace if attn_metadata.is_flash_attn_support_inplace \ + else False + if inplace: attn_output = query[..., :self.v_head_size] else: diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py index 901717349f..bb03e54b4f 100644 --- a/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py @@ -62,12 +62,6 @@ def update_step_context(cls, step_context): cu_seqlens = torch.zeros(batch_size+1, dtype=torch.int32, device=device) cu_seqlens[:-1] = step_context.q_start_loc cu_seqlens[-1] = step_context.q_seqlens.sum() - cu_seqlens_list = cu_seqlens.tolist() - - if not step_context.is_decoding: - cos_sin_ids = step_context.position_ids[0].to(torch.int32) - else: - cos_sin_ids = torch.zeros(batch_size, dtype=torch.int32, device=device) if not 
step_context.is_decoding: is_unpaged_prefill = \ @@ -104,7 +98,7 @@ def update_step_context(cls, step_context): max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, cu_seqlens=cu_seqlens, - cos_sin_ids=cos_sin_ids, + is_flash_attn_support_inplace=False, ) step_context.attn_metadata = attn_metadata diff --git a/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py index 94e42a6b1d..c936c5cbf9 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py +++ b/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py @@ -15,7 +15,7 @@ def apply_rotary_pos_emb( ) -> Tuple[Tensor, Tensor]: query_states = query_states.contiguous() key_states = key_states.contiguous() - query_states, key_states = ext_ops.apply_rotary_pos_emb(query_states, key_states, cos, sin, None, cos_sin_ids, cu_seqlens) + query_states, key_states = ext_ops.apply_rotary_pos_emb(query_states, key_states, cos, sin, None, cu_seqlens) if q_embed is None: q_embed = query_states diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py index 267a96740e..e1b257764e 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py +++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py @@ -25,8 +25,7 @@ def prefill_attention( num_kv_heads = value_states.shape[1] if is_unpaged_prefill: - output = torch.empty_like(query_states) - ext_ops.prefill_attention( + return ext_ops.prefill_attention( query_states, key_states, value_states, @@ -36,10 +35,8 @@ def prefill_attention( num_q_heads, num_kv_heads, attn_mask, - attn_output=output, + attn_output=attn_output, ) - attn_output.copy_(output) - return attn_output else: return ext_ops.paged_prefill_attention( query_states, diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index cf0706659a..a88ee64dd9 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -75,6 +75,7 @@ def forward( query_states, key_states, value_states = self.wqkv.split_qkv( qkv_states) + cu_seqlens = attn_metadata.cu_seqlens # apply rotary embedding cos, sin = rotary_pos_emb query_states, key_states = self.apply_rotary_pos_emb( @@ -82,7 +83,7 @@ def forward( key_states, cos, sin, - attn_metadata, + cu_seqlens, inplace=True, ) diff --git a/lmdeploy/pytorch/nn/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py index 9b1dd3eea7..f8e435825e 100644 --- a/lmdeploy/pytorch/nn/rotary_embedding.py +++ b/lmdeploy/pytorch/nn/rotary_embedding.py @@ -3,7 +3,6 @@ from transformers import PretrainedConfig from ..backends import OpType, get_backend -from ..backends.attention import AttentionMetadata from ..backends.rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters, RopeType, YarnParameters) @@ -123,7 +122,7 @@ def forward(self, key: Tensor, cos: Tensor, sin: Tensor, - attn_metadata: AttentionMetadata, + cu_seqlens: Tensor, inplace: bool = True): """forward.""" - return self.impl.forward(query, key, cos, sin, attn_metadata, inplace) + return self.impl.forward(query, key, cos, sin, cu_seqlens, inplace) From 1fae399da23ec73bd9288ffd08934e8de19e8e10 Mon Sep 17 00:00:00 2001 From: wanfengcxz <2917021186@qq.com> Date: Fri, 25 Oct 2024 00:27:08 +0800 Subject: [PATCH 07/14] refactor(camb) --- .../backends/dlinfer/apply_rotary_emb.py | 3 +-- .../pytorch/backends/dlinfer/attention.py | 6 +++++- .../backends/dlinfer/camb/op_backend.py | 1 + .../kernels/dlinfer/apply_rotary_pos_emb.py | 2 +- 
.../pytorch/kernels/dlinfer/pagedattention.py | 19 ++----------------- lmdeploy/pytorch/models/internlm2.py | 2 -- lmdeploy/pytorch/nn/rotary_embedding.py | 3 +-- 7 files changed, 11 insertions(+), 25 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py b/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py index f474ca0804..5d5c6b034a 100644 --- a/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py +++ b/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py @@ -14,7 +14,6 @@ def forward(self, key: Tensor, cos: Tensor, sin: Tensor, - cu_seqlens: Tensor, inplace: bool = True): """forward.""" if inplace: @@ -23,7 +22,7 @@ def forward(self, else: q_embed = torch.empty_like(query) k_embed = torch.empty_like(key) - return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed, cu_seqlens) + return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed) class DlinferApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder): diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index 52fd3d6c80..c667c86387 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -17,6 +17,7 @@ class DlinferAttentionMetadata(AttentionMetadata): max_kv_seq_len: int = 1 cu_seqlens: Optional[Tensor] = None is_flash_attn_support_inplace: bool = True + is_mock_q_start_loc: bool = False class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]): """dlinfer attention implementation.""" @@ -76,6 +77,7 @@ def forward( max_q_seq_len = attn_metadata.max_q_seq_len max_kv_seq_len = attn_metadata.max_kv_seq_len cu_seqlens = attn_metadata.cu_seqlens + is_mock_q_start_loc = attn_metadata.is_mock_q_start_loc # fill kv cache k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache, @@ -85,6 +87,9 @@ def forward( inplace = inplace if attn_metadata.is_flash_attn_support_inplace \ else False + if is_mock_q_start_loc: + q_start_loc = cu_seqlens + if inplace: attn_output = query[..., :self.v_head_size] else: @@ -107,7 +112,6 @@ def forward( max_kv_seq_len=max_kv_seq_len, is_decoding=is_decoding, block_size=block_size, - cu_seqlens=cu_seqlens, attn_mask=attn_mask, is_unpaged_prefill=is_unpaged_prefill, ) diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py index bb03e54b4f..c003983e2f 100644 --- a/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py @@ -99,6 +99,7 @@ def update_step_context(cls, step_context): max_kv_seq_len=max_kv_seq_len, cu_seqlens=cu_seqlens, is_flash_attn_support_inplace=False, + is_mock_q_start_loc=True, ) step_context.attn_metadata = attn_metadata diff --git a/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py index c936c5cbf9..d714241ed3 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py +++ b/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py @@ -15,7 +15,7 @@ def apply_rotary_pos_emb( ) -> Tuple[Tensor, Tensor]: query_states = query_states.contiguous() key_states = key_states.contiguous() - query_states, key_states = ext_ops.apply_rotary_pos_emb(query_states, key_states, cos, sin, None, cu_seqlens) + query_states, key_states = ext_ops.apply_rotary_pos_emb(query_states, key_states, cos, sin, None) if q_embed is None: q_embed = query_states diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py 
b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py index e1b257764e..98996e8e5a 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py +++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py @@ -17,23 +17,17 @@ def prefill_attention( kv_seq_len: Tensor, max_q_seq_len: int, block_size: int, - cu_seqlens: Tensor, attn_mask: Sequence[Optional[Tensor]], is_unpaged_prefill: Optional[bool], -) -> Tensor: - num_q_heads = query_states.shape[1] - num_kv_heads = value_states.shape[1] - +): if is_unpaged_prefill: return ext_ops.prefill_attention( query_states, key_states, value_states, - cu_seqlens, + q_start_loc, q_seq_len, max_q_seq_len, - num_q_heads, - num_kv_heads, attn_mask, attn_output=attn_output, ) @@ -56,11 +50,6 @@ def prefill_attention( def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, max_kv_seq_len, block_offsets, block_size): - num_q_heads = q.shape[1] - num_kv_heads = k_cache.shape[1] - q = q.unsqueeze(1) - attn_output = attn_output.unsqueeze(1) - return ext_ops.paged_decode_attention( q, k_cache, @@ -69,8 +58,6 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, block_size, kv_seq_len, max_kv_seq_len, - num_q_heads, - num_kv_heads, attn_output=attn_output, ) @@ -90,7 +77,6 @@ def paged_attention_fwd( max_kv_seq_len: int, is_decoding: bool, block_size: int, - cu_seqlens: Tensor, attn_mask: Sequence[Optional[Tensor]] = (), is_unpaged_prefill: Optional[bool] = None, ): @@ -108,7 +94,6 @@ def paged_attention_fwd( kv_seqlens, max_q_seq_len, block_size, - cu_seqlens, attn_mask, is_unpaged_prefill, ) diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index a88ee64dd9..b9434fcd05 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -75,7 +75,6 @@ def forward( query_states, key_states, value_states = self.wqkv.split_qkv( qkv_states) - cu_seqlens = attn_metadata.cu_seqlens # apply rotary embedding cos, sin = rotary_pos_emb query_states, key_states = self.apply_rotary_pos_emb( @@ -83,7 +82,6 @@ def forward( key_states, cos, sin, - cu_seqlens, inplace=True, ) diff --git a/lmdeploy/pytorch/nn/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py index f8e435825e..43eb1f806d 100644 --- a/lmdeploy/pytorch/nn/rotary_embedding.py +++ b/lmdeploy/pytorch/nn/rotary_embedding.py @@ -122,7 +122,6 @@ def forward(self, key: Tensor, cos: Tensor, sin: Tensor, - cu_seqlens: Tensor, inplace: bool = True): """forward.""" - return self.impl.forward(query, key, cos, sin, cu_seqlens, inplace) + return self.impl.forward(query, key, cos, sin, inplace) From 3f8f2a74da9273bab5997d7f147ee1b2fcd40c9d Mon Sep 17 00:00:00 2001 From: root Date: Mon, 28 Oct 2024 16:13:11 +0800 Subject: [PATCH 08/14] support camb graph --- .../backends/dlinfer/camb/graph_runner.py | 326 ++++++++++++++++++ .../backends/dlinfer/camb/op_backend.py | 11 + 2 files changed, 337 insertions(+) create mode 100644 lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py b/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py new file mode 100644 index 0000000000..ba64313cb8 --- /dev/null +++ b/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py @@ -0,0 +1,326 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
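+# This module provides the Cambricon (MLU) graph runner: CAMBSingleGraphRunner
+# captures one forward pass into a device graph backed by persistent
+# input/output buffers, and CAMBGraphRunner caches one runner per
+# (padded token count, is_decoding) key and replays it on later calls.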
+from typing import Any, Dict, List, Tuple + +import torch +import torch_mlu +from torch_mlu.utils.model_transfer import transfer +from torch import Tensor + +from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig +from lmdeploy.pytorch.model_inputs import StepContext +from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta +from lmdeploy.utils import get_logger + +from ...graph_runner import GraphRunner + +logger = get_logger('lmdeploy') + +BuffType = Dict[str, Tensor] + +def next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n.""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n + + +def _false(*args, **kwargs): + """default value of not support cuda graph.""" + return False + + +class CAMBSingleGraphRunner: + """camb single graph runner.""" + + def __init__( + self, + model: torch.nn.Module, + max_batches: int, + max_tokens: int, + num_blocks: int, + is_decoding: bool, + pool: Tuple[int, int], + device: torch.device, + ): + self.model = model + self.ctx_mgr = model.ctx_mgr + self.meta = CudaGraphMeta( + max_batchs=max_batches, + max_tokens=max_tokens, + num_blocks=num_blocks, + is_decoding=is_decoding, + device=device, + input_buffers=dict(), + output_buffers=dict(), + ) + self.device = device + self.max_batches = max_batches + self.max_tokens = max_tokens + self.num_blocks = num_blocks + self.is_decoding = is_decoding + self.pool = pool + self._graph: torch.cuda.CUDAGraph = None + + def capture(self, **kwargs): + """capture graph.""" + self.meta.input_buffers = self.make_Camb_buffers( + self.meta, **kwargs) + # padded_kwargs = self.model.fill_buffers_cudagraph(self.meta, **kwargs) + padded_kwargs = self.update_Camb_buffer(self.meta, **kwargs) + + context = self.ctx_mgr.current_context() + # self.model.update_context_cudagraph(self.meta, context) + self.update_Camb_context(self.meta, context) + current_stream = torch.cuda.current_stream() + # warmup + self.model(**padded_kwargs) + + self._graph = torch.cuda.CUDAGraph() + # unsafe kernel call in other thread might invalid the capture + # so we set thread_safe capture mode here. 
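+        # Kernels launched on current_stream inside the context below are recorded
+        # into self._graph; forward() later refreshes the persistent input buffers
+        # in place and replays the recorded graph instead of re-launching kernels.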
+ with torch.cuda.graph(self._graph, + pool=self.pool, + stream=current_stream, + capture_error_mode='thread_local'): + output = self.model(**padded_kwargs) + + output_buffers = dict(logits=output) + self.meta.output_buffers = output_buffers + return output + + def make_Camb_buffers(self, graph_meta: CudaGraphMeta, *args, + **kwargs) -> BuffType: + """make cudagraph buffers from forward inputs.""" + max_batches = graph_meta.max_batchs + max_tokens = graph_meta.max_tokens + num_blocks = graph_meta.num_blocks + device = graph_meta.device + + input_buffers: BuffType = dict() + input_buffers['input_ids'] = torch.zeros(1, + max_tokens, + dtype=torch.int32, + device=device) + input_buffers['position_ids'] = torch.zeros((1, max_tokens), + dtype=torch.int32, + device=device) + + input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks), + dtype=torch.int32, + device=device) + input_buffers['q_start_loc'] = torch.zeros(max_batches, + dtype=torch.int32, + device=device) + input_buffers['q_seqlens'] = torch.zeros(max_batches, + dtype=torch.int32, + device=device) + input_buffers['kv_seqlens'] = torch.zeros(max_batches, + dtype=torch.int32, + device=device) + + input_buffers['cu_seqlens'] = torch.zeros(max_batches+1, + dtype=torch.int32, + device=device) + input_buffers['kv_start_indices'] = torch.ones(max_batches*max_tokens, + dtype=torch.int32, + device=device) * 512 + + input_buffers['local_adapter_ids'] = torch.zeros(max_batches, + dtype=torch.int32, + device=device) + return input_buffers + + def update_Camb_buffer(self, graph_meta: CudaGraphMeta, + input_ids: Tensor, position_ids: Tensor, + past_key_values: List, attn_metadata: Any, + inputs_embeds: Tensor, + **kwargs) -> Dict[str, Tensor]: + """fill cudagraph buffers from forward inputs.""" + is_decoding = graph_meta.is_decoding + block_offsets: Tensor = attn_metadata.block_offsets + q_start_loc: Tensor = attn_metadata.q_start_loc + q_seqlens: Tensor = attn_metadata.q_seqlens + kv_seqlens: Tensor = attn_metadata.kv_seqlens + + cu_seqlens: Tensor = attn_metadata.cu_seqlens + kv_start_indices: Tensor = attn_metadata.kv_start_indices + + input_buffers: BuffType = graph_meta.input_buffers + + batch_size, num_blocks = block_offsets.size() + num_tokens = input_ids.size(-1) + # fill buffer + input_buffers['input_ids'][:, :num_tokens] = input_ids + input_buffers['position_ids'][:, :num_tokens] = position_ids + input_buffers[ + 'block_offsets'][:batch_size, :num_blocks] = block_offsets + if q_seqlens.data_ptr() != input_buffers['q_seqlens'].data_ptr(): + input_buffers['q_seqlens'].zero_() + input_buffers['q_seqlens'][:batch_size] = q_seqlens + if kv_seqlens.data_ptr() != input_buffers['kv_seqlens'].data_ptr(): + input_buffers['kv_seqlens'].zero_() + input_buffers['kv_seqlens'][:batch_size] = kv_seqlens + input_buffers['q_start_loc'][:batch_size] = q_start_loc + + input_buffers['cu_seqlens'][:batch_size+1] = cu_seqlens + input_buffers['kv_start_indices'][:num_tokens] = kv_start_indices[:num_tokens] + + if inputs_embeds is not None: + emb_size = inputs_embeds.size(-1) + if 'inputs_embeds' not in input_buffers: + max_num_tokens = input_buffers['input_ids'].size(-1) + input_buffers['inputs_embeds'] = inputs_embeds.new_zeros( + 1, max_num_tokens, emb_size) + input_buffers['inputs_embeds'][:, :num_tokens] = inputs_embeds + + # create inputs + new_batch_size = next_power_of_2(batch_size) + new_num_tokens = next_power_of_2(num_tokens) + + attn_metadata.block_offsets = input_buffers[ + 'block_offsets'][:new_batch_size] + attn_metadata.q_start_loc = 
input_buffers[ + 'q_start_loc'][:new_batch_size] + attn_metadata.q_seqlens = input_buffers['q_seqlens'][:new_batch_size] + attn_metadata.kv_seqlens = input_buffers['kv_seqlens'][:new_batch_size] + + attn_metadata.cu_seqlens = input_buffers['cu_seqlens'][:batch_size+1] + attn_metadata.kv_start_indices = input_buffers['kv_start_indices'][:new_num_tokens] + new_inputs = dict( + past_key_values=past_key_values, + attn_metadata=attn_metadata, + ) + + if is_decoding: + new_inputs['input_ids'] = input_buffers[ + 'input_ids'][:, :new_batch_size] + new_inputs['position_ids'] = input_buffers[ + 'position_ids'][:, :new_batch_size] + else: + new_inputs['input_ids'] = input_buffers['input_ids'] + new_inputs['position_ids'] = input_buffers['position_ids'] + + if inputs_embeds is not None: + if is_decoding: + new_inputs['inputs_embeds'] = input_buffers[ + 'inputs_embeds'][:, :new_batch_size] + else: + new_inputs['inputs_embeds'] = input_buffers['inputs_embeds'] + + new_inputs.update(kwargs) + return new_inputs + + def update_Camb_context(self, graph_meta, context): + """update step context with input buffers.""" + input_buffers = graph_meta.input_buffers + local_adapter_ids = context.local_adapter_ids + if local_adapter_ids is not None: + if input_buffers['local_adapter_ids'].data_ptr( + ) != local_adapter_ids.data_ptr(): + input_buffers['local_adapter_ids'].fill_(0) + batch_size = local_adapter_ids.size(0) + input_buffers['local_adapter_ids'][:batch_size] = local_adapter_ids + context.local_adapter_ids = input_buffers['local_adapter_ids'] + context.q_seqlens = input_buffers['q_seqlens'] + context.kv_seqlens = input_buffers['kv_seqlens'] + context.q_start_loc = input_buffers['q_start_loc'] + context.cu_seqlens = input_buffers['cu_seqlens'] + context.kv_start_indices = input_buffers['kv_start_indices'] + + def forward(self, **kwargs): + """forward.""" + num_tokens = kwargs['input_ids'].size(-1) + assert self._graph is not None + self.update_Camb_buffer(self.meta, **kwargs) + context = self.ctx_mgr.current_context() + self.update_Camb_context(self.meta,context) + + self._graph.replay() + + output = self.meta.output_buffers['logits'][:, :num_tokens] + return output + + def __del__(self): + """del.""" + del self._graph + + +class CAMBGraphRunner(GraphRunner): + """CAMB graph runner.""" + + def __init__(self, model: torch.nn.Module, model_config: ModelConfig, + cache_config: CacheConfig, backend_config: BackendConfig, + device: torch.device): + super().__init__(model, model_config, cache_config, backend_config, + device) + self.max_batches = cache_config.max_batches + self.max_tokens = cache_config.max_prefill_token_num + self.num_blocks = cache_config.num_gpu_blocks + + self.enable_graph = self.check_enable_graph() + + self.graph_pool_handle = torch.cuda.graph_pool_handle() + self._runner_map: Dict[Any, CAMBSingleGraphRunner] = dict() + + def check_enable_graph(self): + """check enable graph.""" + if self.backend_config.eager_mode: + return _false + + return getattr(self.model, 'support_cuda_graph', _false) + + def get_graph_key(self, input_ids: torch.Tensor, + position_ids: torch.Tensor, past_key_values: List, + attn_metadata: Any, inputs_embeds: torch.Tensor, + **kwargs): + """get graph key.""" + context = self.ctx_mgr.current_context() + is_decoding = context.is_decoding + num_tokens = input_ids.numel() + new_num_tokens = next_power_of_2(num_tokens) + return (new_num_tokens, is_decoding) + + def __call__(self, **kwargs): + """call.""" + enable_graph = self.enable_graph(**kwargs) + + if not enable_graph: + 
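+            # fall back to eager execution when graph capture is unsupported
+            # (eager_mode set in the backend config, or the model does not
+            # declare support_cuda_graph)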
return self.model(**kwargs) + + graph_key = self.get_graph_key(**kwargs) + max_tokens = graph_key[0] + is_decoding = graph_key[1] + if graph_key not in self._runner_map: + max_batches = max_tokens if is_decoding else self.max_batches + runner = CAMBSingleGraphRunner(self.model, + max_batches=max_batches, + max_tokens=max_tokens, + num_blocks=self.num_blocks, + is_decoding=is_decoding, + pool=self.graph_pool_handle, + device=self.device) + runner.capture(**kwargs) + self._runner_map[graph_key] = runner + else: + runner = self._runner_map[graph_key] + output = runner.forward(**kwargs) + return output + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare inputs.""" + return self.model.prepare_inputs_for_generation( + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + context=context, + ) diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py index c003983e2f..41dec1dd21 100644 --- a/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py @@ -3,8 +3,10 @@ import torch +from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig from lmdeploy.utils import get_logger +from ...base import OpType from ..op_backend import DlinferOpsBackend logger = get_logger('lmdeploy') @@ -105,3 +107,12 @@ def update_step_context(cls, step_context): step_context.attn_metadata = attn_metadata return step_context + @staticmethod + def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, + cache_config: CacheConfig, + backend_config: BackendConfig, + device: torch.device): + """build graph runner.""" + from .graph_runner import CAMBGraphRunner + return CAMBGraphRunner(model, model_config, cache_config, + backend_config, device) \ No newline at end of file From 96d36a42918aecf6750d9104ade95b3910cd529c Mon Sep 17 00:00:00 2001 From: wanfengcxz <2917021186@qq.com> Date: Mon, 28 Oct 2024 22:27:15 +0800 Subject: [PATCH 09/14] refactor(camb) --- lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py index d714241ed3..be1698289f 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py +++ b/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py @@ -15,7 +15,7 @@ def apply_rotary_pos_emb( ) -> Tuple[Tensor, Tensor]: query_states = query_states.contiguous() key_states = key_states.contiguous() - query_states, key_states = ext_ops.apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + query_states, key_states = ext_ops.apply_rotary_pos_emb(query_states, key_states, cos, sin, None, None) if q_embed is None: q_embed = query_states From fec0eb88e6228bb4f770972cecdfb99ad93ca8d3 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 29 Oct 2024 10:24:34 +0800 Subject: [PATCH 10/14] change camb graph --- .../backends/dlinfer/camb/graph_runner.py | 35 ++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py b/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py index ba64313cb8..847196a536 100644 --- a/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py +++ b/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py @@ -16,19 +16,23 @@ logger = 
get_logger('lmdeploy') BuffType = Dict[str, Tensor] - -def next_power_of_2(n: int): - """Return the smallest power of 2 greater than or equal to n.""" - n -= 1 - n |= n >> 1 - n |= n >> 2 - n |= n >> 4 - n |= n >> 8 - n |= n >> 16 - n |= n >> 32 - n += 1 +def tensor_padding(n: int): + """currently camb seems support all n.""" + """ for a LLM model, we need to capture 2 graphs, one is prefill stage(fixed total seqLens) + and the other is decoder stage(batch is fixed)""" return n + # """Return the smallest power of 2 greater than or equal to n.""" + # n -= 1 + # n |= n >> 1 + # n |= n >> 2 + # n |= n >> 4 + # n |= n >> 8 + # n |= n >> 16 + # n |= n >> 32 + # n += 1 + # return n + def _false(*args, **kwargs): """default value of not support cuda graph.""" @@ -71,11 +75,9 @@ def capture(self, **kwargs): """capture graph.""" self.meta.input_buffers = self.make_Camb_buffers( self.meta, **kwargs) - # padded_kwargs = self.model.fill_buffers_cudagraph(self.meta, **kwargs) padded_kwargs = self.update_Camb_buffer(self.meta, **kwargs) context = self.ctx_mgr.current_context() - # self.model.update_context_cudagraph(self.meta, context) self.update_Camb_context(self.meta, context) current_stream = torch.cuda.current_stream() # warmup @@ -168,6 +170,7 @@ def update_Camb_buffer(self, graph_meta: CudaGraphMeta, input_buffers['kv_seqlens'][:batch_size] = kv_seqlens input_buffers['q_start_loc'][:batch_size] = q_start_loc + input_buffers['cu_seqlens'][:batch_size+1] = cu_seqlens input_buffers['kv_start_indices'][:num_tokens] = kv_start_indices[:num_tokens] @@ -180,8 +183,8 @@ def update_Camb_buffer(self, graph_meta: CudaGraphMeta, input_buffers['inputs_embeds'][:, :num_tokens] = inputs_embeds # create inputs - new_batch_size = next_power_of_2(batch_size) - new_num_tokens = next_power_of_2(num_tokens) + new_batch_size = tensor_padding(batch_size) + new_num_tokens = tensor_padding(num_tokens) attn_metadata.block_offsets = input_buffers[ 'block_offsets'][:new_batch_size] @@ -283,7 +286,7 @@ def get_graph_key(self, input_ids: torch.Tensor, context = self.ctx_mgr.current_context() is_decoding = context.is_decoding num_tokens = input_ids.numel() - new_num_tokens = next_power_of_2(num_tokens) + new_num_tokens = tensor_padding(num_tokens) return (new_num_tokens, is_decoding) def __call__(self, **kwargs): From b2e7df2784d6f7e66c9704e786550aca85880a94 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 31 Oct 2024 14:38:00 +0800 Subject: [PATCH 11/14] change graph --- lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py b/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py index 847196a536..69075208a3 100644 --- a/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py +++ b/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py @@ -80,6 +80,7 @@ def capture(self, **kwargs): context = self.ctx_mgr.current_context() self.update_Camb_context(self.meta, context) current_stream = torch.cuda.current_stream() + # warmup self.model(**padded_kwargs) @@ -292,13 +293,13 @@ def get_graph_key(self, input_ids: torch.Tensor, def __call__(self, **kwargs): """call.""" enable_graph = self.enable_graph(**kwargs) - - if not enable_graph: - return self.model(**kwargs) - graph_key = self.get_graph_key(**kwargs) max_tokens = graph_key[0] is_decoding = graph_key[1] + + if (not enable_graph) or (not is_decoding): + return self.model(**kwargs) + if graph_key not in self._runner_map: max_batches = max_tokens if 
is_decoding else self.max_batches runner = CAMBSingleGraphRunner(self.model, @@ -312,6 +313,7 @@ def __call__(self, **kwargs): self._runner_map[graph_key] = runner else: runner = self._runner_map[graph_key] + output = runner.forward(**kwargs) return output From 48b91da33b3f3dfe297c8a7ae8fbd961b67e8eef Mon Sep 17 00:00:00 2001 From: wanfengcxz <2917021186@qq.com> Date: Wed, 6 Nov 2024 17:07:28 +0800 Subject: [PATCH 12/14] support server test(camb) --- lmdeploy/cli/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index ad7a058c8f..e0675e29a5 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -367,7 +367,7 @@ def calib_search_scale(parser): @staticmethod def device(parser, default: str = 'cuda', - choices: List[str] = ['cuda', 'ascend', 'maca']): + choices: List[str] = ['cuda', 'ascend', 'maca', 'camb']): """Add argument device to parser.""" return parser.add_argument('--device', From 268d7242fccd8cf87838ae832aeb5834f4558e11 Mon Sep 17 00:00:00 2001 From: wanfengcxz <2917021186@qq.com> Date: Fri, 8 Nov 2024 16:11:55 +0800 Subject: [PATCH 13/14] change --- lmdeploy/pytorch/kernels/dlinfer/pagedattention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py index 98996e8e5a..0fd50ce932 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py +++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py @@ -19,7 +19,7 @@ def prefill_attention( block_size: int, attn_mask: Sequence[Optional[Tensor]], is_unpaged_prefill: Optional[bool], -): +) -> Tensor: if is_unpaged_prefill: return ext_ops.prefill_attention( query_states, From 68a72e332b67ac98357f42887dbdfd0531ee740d Mon Sep 17 00:00:00 2001 From: wanfengcxz <2917021186@qq.com> Date: Fri, 15 Nov 2024 15:48:59 +0800 Subject: [PATCH 14/14] support camb graph with dummpy input --- .../backends/dlinfer/camb/graph_runner.py | 86 ++++++++----------- 1 file changed, 34 insertions(+), 52 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py b/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py index 69075208a3..e746fa7ce7 100644 --- a/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py +++ b/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py @@ -16,22 +16,9 @@ logger = get_logger('lmdeploy') BuffType = Dict[str, Tensor] -def tensor_padding(n: int): - """currently camb seems support all n.""" - """ for a LLM model, we need to capture 2 graphs, one is prefill stage(fixed total seqLens) - and the other is decoder stage(batch is fixed)""" - return n - - # """Return the smallest power of 2 greater than or equal to n.""" - # n -= 1 - # n |= n >> 1 - # n |= n >> 2 - # n |= n >> 4 - # n |= n >> 8 - # n |= n >> 16 - # n |= n >> 32 - # n += 1 - # return n + +def round_up_to_multiple_of_8(n: int): + return (n + 7) // 8 * 8 def _false(*args, **kwargs): @@ -69,25 +56,25 @@ def __init__( self.num_blocks = num_blocks self.is_decoding = is_decoding self.pool = pool - self._graph: torch.cuda.CUDAGraph = None + self._graph: torch.mlu.CUDAGraph = None def capture(self, **kwargs): """capture graph.""" - self.meta.input_buffers = self.make_Camb_buffers( + self.meta.input_buffers = self.make_camb_buffers( self.meta, **kwargs) - padded_kwargs = self.update_Camb_buffer(self.meta, **kwargs) + padded_kwargs = self.update_camb_buffer(self.meta, **kwargs) context = self.ctx_mgr.current_context() - self.update_Camb_context(self.meta, context) 
- current_stream = torch.cuda.current_stream() + self.update_camb_context(self.meta, context) + current_stream = torch.mlu.current_stream() # warmup - self.model(**padded_kwargs) + output = self.model(**padded_kwargs) - self._graph = torch.cuda.CUDAGraph() + self._graph = torch.mlu.CUDAGraph() # unsafe kernel call in other thread might invalid the capture # so we set thread_safe capture mode here. - with torch.cuda.graph(self._graph, + with torch.mlu.graph(self._graph, pool=self.pool, stream=current_stream, capture_error_mode='thread_local'): @@ -97,7 +84,7 @@ def capture(self, **kwargs): self.meta.output_buffers = output_buffers return output - def make_Camb_buffers(self, graph_meta: CudaGraphMeta, *args, + def make_camb_buffers(self, graph_meta: CudaGraphMeta, *args, **kwargs) -> BuffType: """make cudagraph buffers from forward inputs.""" max_batches = graph_meta.max_batchs @@ -110,36 +97,36 @@ def make_Camb_buffers(self, graph_meta: CudaGraphMeta, *args, max_tokens, dtype=torch.int32, device=device) - input_buffers['position_ids'] = torch.zeros((1, max_tokens), + input_buffers['position_ids'] = torch.ones((1, max_tokens), dtype=torch.int32, device=device) input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks), dtype=torch.int32, device=device) - input_buffers['q_start_loc'] = torch.zeros(max_batches, + + input_buffers['q_start_loc'] = torch.arange(max_batches, dtype=torch.int32, device=device) - input_buffers['q_seqlens'] = torch.zeros(max_batches, + + input_buffers['q_seqlens'] = torch.ones(max_batches, dtype=torch.int32, device=device) - input_buffers['kv_seqlens'] = torch.zeros(max_batches, + + input_buffers['kv_seqlens'] = torch.ones(max_batches, dtype=torch.int32, device=device) - input_buffers['cu_seqlens'] = torch.zeros(max_batches+1, - dtype=torch.int32, - device=device) - input_buffers['kv_start_indices'] = torch.ones(max_batches*max_tokens, + input_buffers['kv_start_indices'] = -torch.ones((max_batches*max_tokens), dtype=torch.int32, - device=device) * 512 - + device=device) + input_buffers['local_adapter_ids'] = torch.zeros(max_batches, dtype=torch.int32, device=device) return input_buffers - def update_Camb_buffer(self, graph_meta: CudaGraphMeta, + def update_camb_buffer(self, graph_meta: CudaGraphMeta, input_ids: Tensor, position_ids: Tensor, past_key_values: List, attn_metadata: Any, inputs_embeds: Tensor, @@ -150,8 +137,6 @@ def update_Camb_buffer(self, graph_meta: CudaGraphMeta, q_start_loc: Tensor = attn_metadata.q_start_loc q_seqlens: Tensor = attn_metadata.q_seqlens kv_seqlens: Tensor = attn_metadata.kv_seqlens - - cu_seqlens: Tensor = attn_metadata.cu_seqlens kv_start_indices: Tensor = attn_metadata.kv_start_indices input_buffers: BuffType = graph_meta.input_buffers @@ -163,16 +148,15 @@ def update_Camb_buffer(self, graph_meta: CudaGraphMeta, input_buffers['position_ids'][:, :num_tokens] = position_ids input_buffers[ 'block_offsets'][:batch_size, :num_blocks] = block_offsets - if q_seqlens.data_ptr() != input_buffers['q_seqlens'].data_ptr(): - input_buffers['q_seqlens'].zero_() + # if q_seqlens.data_ptr() != input_buffers['q_seqlens'].data_ptr(): + # input_buffers['q_seqlens'].zero_() input_buffers['q_seqlens'][:batch_size] = q_seqlens - if kv_seqlens.data_ptr() != input_buffers['kv_seqlens'].data_ptr(): - input_buffers['kv_seqlens'].zero_() + # if kv_seqlens.data_ptr() != input_buffers['kv_seqlens'].data_ptr(): + # input_buffers['kv_seqlens'].zero_() input_buffers['kv_seqlens'][:batch_size] = kv_seqlens input_buffers['q_start_loc'][:batch_size] = 
q_start_loc - input_buffers['cu_seqlens'][:batch_size+1] = cu_seqlens input_buffers['kv_start_indices'][:num_tokens] = kv_start_indices[:num_tokens] if inputs_embeds is not None: @@ -184,8 +168,8 @@ def update_Camb_buffer(self, graph_meta: CudaGraphMeta, input_buffers['inputs_embeds'][:, :num_tokens] = inputs_embeds # create inputs - new_batch_size = tensor_padding(batch_size) - new_num_tokens = tensor_padding(num_tokens) + new_batch_size = round_up_to_multiple_of_8(batch_size) + new_num_tokens = round_up_to_multiple_of_8(num_tokens) attn_metadata.block_offsets = input_buffers[ 'block_offsets'][:new_batch_size] @@ -194,7 +178,6 @@ def update_Camb_buffer(self, graph_meta: CudaGraphMeta, attn_metadata.q_seqlens = input_buffers['q_seqlens'][:new_batch_size] attn_metadata.kv_seqlens = input_buffers['kv_seqlens'][:new_batch_size] - attn_metadata.cu_seqlens = input_buffers['cu_seqlens'][:batch_size+1] attn_metadata.kv_start_indices = input_buffers['kv_start_indices'][:new_num_tokens] new_inputs = dict( past_key_values=past_key_values, @@ -220,7 +203,7 @@ def update_Camb_buffer(self, graph_meta: CudaGraphMeta, new_inputs.update(kwargs) return new_inputs - def update_Camb_context(self, graph_meta, context): + def update_camb_context(self, graph_meta, context): """update step context with input buffers.""" input_buffers = graph_meta.input_buffers local_adapter_ids = context.local_adapter_ids @@ -234,16 +217,15 @@ def update_Camb_context(self, graph_meta, context): context.q_seqlens = input_buffers['q_seqlens'] context.kv_seqlens = input_buffers['kv_seqlens'] context.q_start_loc = input_buffers['q_start_loc'] - context.cu_seqlens = input_buffers['cu_seqlens'] context.kv_start_indices = input_buffers['kv_start_indices'] def forward(self, **kwargs): """forward.""" num_tokens = kwargs['input_ids'].size(-1) assert self._graph is not None - self.update_Camb_buffer(self.meta, **kwargs) + self.update_camb_buffer(self.meta, **kwargs) context = self.ctx_mgr.current_context() - self.update_Camb_context(self.meta,context) + self.update_camb_context(self.meta,context) self._graph.replay() @@ -269,7 +251,7 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, self.enable_graph = self.check_enable_graph() - self.graph_pool_handle = torch.cuda.graph_pool_handle() + self.graph_pool_handle = torch.mlu.graph_pool_handle() self._runner_map: Dict[Any, CAMBSingleGraphRunner] = dict() def check_enable_graph(self): @@ -287,7 +269,7 @@ def get_graph_key(self, input_ids: torch.Tensor, context = self.ctx_mgr.current_context() is_decoding = context.is_decoding num_tokens = input_ids.numel() - new_num_tokens = tensor_padding(num_tokens) + new_num_tokens = round_up_to_multiple_of_8(num_tokens) return (new_num_tokens, is_decoding) def __call__(self, **kwargs):