diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py
index ad7a058c8f..e0675e29a5 100644
--- a/lmdeploy/cli/utils.py
+++ b/lmdeploy/cli/utils.py
@@ -367,7 +367,7 @@ def calib_search_scale(parser):
     @staticmethod
     def device(parser,
                default: str = 'cuda',
-               choices: List[str] = ['cuda', 'ascend', 'maca']):
+               choices: List[str] = ['cuda', 'ascend', 'maca', 'camb']):
         """Add argument device to parser."""
 
         return parser.add_argument('--device',
diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index 90823598ea..2f4b7eb6a4 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -291,7 +291,7 @@ def __post_init__(self):
         assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
         assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'
         assert self.device_type in [
-            'cuda', 'ascend', 'maca'
+            'cuda', 'ascend', 'maca', 'camb'
         ], (f'invalid device_type: {self.device_type}')
         if self.quant_policy > 0 and self.device_type != 'cuda':
             assert False, 'kv cache quantization only works for CUDA.'
diff --git a/lmdeploy/pytorch/backends/dlinfer/__init__.py b/lmdeploy/pytorch/backends/dlinfer/__init__.py
index af3ccff085..b1d2382936 100644
--- a/lmdeploy/pytorch/backends/dlinfer/__init__.py
+++ b/lmdeploy/pytorch/backends/dlinfer/__init__.py
@@ -1,3 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .ascend import AscendOpsBackend  # noqa: F401
 from .maca import MacaOpsBackend  # noqa: F401
+from .camb import CambOpsBackend  # noqa: F401
diff --git a/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py b/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py
index c2bc1f7dce..5d5c6b034a 100644
--- a/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py
+++ b/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py
@@ -6,7 +6,6 @@
 
 from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl
 
-
 class DlinferApplyRotaryEmbImpl(ApplyRotaryEmbImpl):
     """Apply rotary embedding implementation."""
 
diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py
index 0d666c9130..c667c86387 100644
--- a/lmdeploy/pytorch/backends/dlinfer/attention.py
+++ b/lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -15,7 +15,9 @@ class DlinferAttentionMetadata(AttentionMetadata):
     is_unpaged_prefill: Optional[bool] = None
     max_q_seq_len: int = 1
     max_kv_seq_len: int = 1
-
+    cu_seqlens: Optional[Tensor] = None
+    is_flash_attn_support_inplace: bool = True
+    is_mock_q_start_loc: bool = False
 
 class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
     """dlinfer attention implementation."""
@@ -74,11 +76,20 @@ def forward(
         is_unpaged_prefill = attn_metadata.is_unpaged_prefill
         max_q_seq_len = attn_metadata.max_q_seq_len
         max_kv_seq_len = attn_metadata.max_kv_seq_len
+        cu_seqlens = attn_metadata.cu_seqlens
+        is_mock_q_start_loc = attn_metadata.is_mock_q_start_loc
 
         # fill kv cache
         k_cache, v_cache = self.fill_kv_cache(key, value, k_cache, v_cache,
                                               kv_start_indices)
 
+        if is_unpaged_prefill:
+            inplace = inplace if attn_metadata.is_flash_attn_support_inplace \
+                else False
+
+        if is_mock_q_start_loc:
+            q_start_loc = cu_seqlens
+
         if inplace:
             attn_output = query[..., :self.v_head_size]
         else:
diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/__init__.py b/lmdeploy/pytorch/backends/dlinfer/camb/__init__.py
new file mode 100644
index 0000000000..897495c209
--- /dev/null
+++ b/lmdeploy/pytorch/backends/dlinfer/camb/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .op_backend import CambOpsBackend  # noqa: F401
diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py b/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py
new file mode 100644
index 0000000000..e746fa7ce7
--- /dev/null
+++ b/lmdeploy/pytorch/backends/dlinfer/camb/graph_runner.py
@@ -0,0 +1,313 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Dict, List, Tuple
+
+import torch
+import torch_mlu
+from torch_mlu.utils.model_transfer import transfer
+from torch import Tensor
+
+from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
+from lmdeploy.pytorch.model_inputs import StepContext
+from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
+from lmdeploy.utils import get_logger
+
+from ...graph_runner import GraphRunner
+
+logger = get_logger('lmdeploy')
+
+BuffType = Dict[str, Tensor]
+
+def round_up_to_multiple_of_8(n: int):
+    return (n + 7) // 8 * 8
+
+
+def _false(*args, **kwargs):
+    """Default value when cuda graph is not supported."""
+    return False
+
+
+class CAMBSingleGraphRunner:
+    """camb single graph runner."""
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        max_batches: int,
+        max_tokens: int,
+        num_blocks: int,
+        is_decoding: bool,
+        pool: Tuple[int, int],
+        device: torch.device,
+    ):
+        self.model = model
+        self.ctx_mgr = model.ctx_mgr
+        self.meta = CudaGraphMeta(
+            max_batchs=max_batches,
+            max_tokens=max_tokens,
+            num_blocks=num_blocks,
+            is_decoding=is_decoding,
+            device=device,
+            input_buffers=dict(),
+            output_buffers=dict(),
+        )
+        self.device = device
+        self.max_batches = max_batches
+        self.max_tokens = max_tokens
+        self.num_blocks = num_blocks
+        self.is_decoding = is_decoding
+        self.pool = pool
+        self._graph: torch.mlu.CUDAGraph = None
+
+    def capture(self, **kwargs):
+        """capture graph."""
+        self.meta.input_buffers = self.make_camb_buffers(
+            self.meta, **kwargs)
+        padded_kwargs = self.update_camb_buffer(self.meta, **kwargs)
+
+        context = self.ctx_mgr.current_context()
+        self.update_camb_context(self.meta, context)
+        current_stream = torch.mlu.current_stream()
+
+        # warmup
+        output = self.model(**padded_kwargs)
+
+        self._graph = torch.mlu.CUDAGraph()
+        # Unsafe kernel calls in other threads might invalidate the capture,
+        # so we set the capture error mode to thread_local here.
+        with torch.mlu.graph(self._graph,
+                             pool=self.pool,
+                             stream=current_stream,
+                             capture_error_mode='thread_local'):
+            output = self.model(**padded_kwargs)
+
+        output_buffers = dict(logits=output)
+        self.meta.output_buffers = output_buffers
+        return output
+
+    def make_camb_buffers(self, graph_meta: CudaGraphMeta, *args,
+                          **kwargs) -> BuffType:
+        """make cudagraph buffers from forward inputs."""
+        max_batches = graph_meta.max_batchs
+        max_tokens = graph_meta.max_tokens
+        num_blocks = graph_meta.num_blocks
+        device = graph_meta.device
+
+        input_buffers: BuffType = dict()
+        input_buffers['input_ids'] = torch.zeros(1,
+                                                 max_tokens,
+                                                 dtype=torch.int32,
+                                                 device=device)
+        input_buffers['position_ids'] = torch.ones((1, max_tokens),
+                                                   dtype=torch.int32,
+                                                   device=device)
+
+        input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks),
+                                                     dtype=torch.int32,
+                                                     device=device)
+
+        input_buffers['q_start_loc'] = torch.arange(max_batches,
+                                                    dtype=torch.int32,
+                                                    device=device)
+
+        input_buffers['q_seqlens'] = torch.ones(max_batches,
+                                                dtype=torch.int32,
+                                                device=device)
+
+        input_buffers['kv_seqlens'] = torch.ones(max_batches,
+                                                 dtype=torch.int32,
+                                                 device=device)
+
+        input_buffers['kv_start_indices'] = -torch.ones((max_batches*max_tokens),
+                                                        dtype=torch.int32,
+                                                        device=device)
+
+        input_buffers['local_adapter_ids'] = torch.zeros(max_batches,
+                                                         dtype=torch.int32,
+                                                         device=device)
+        return input_buffers
+
+    def update_camb_buffer(self, graph_meta: CudaGraphMeta,
+                           input_ids: Tensor, position_ids: Tensor,
+                           past_key_values: List, attn_metadata: Any,
+                           inputs_embeds: Tensor,
+                           **kwargs) -> Dict[str, Tensor]:
+        """fill cudagraph buffers from forward inputs."""
+        is_decoding = graph_meta.is_decoding
+        block_offsets: Tensor = attn_metadata.block_offsets
+        q_start_loc: Tensor = attn_metadata.q_start_loc
+        q_seqlens: Tensor = attn_metadata.q_seqlens
+        kv_seqlens: Tensor = attn_metadata.kv_seqlens
+        kv_start_indices: Tensor = attn_metadata.kv_start_indices
+
+        input_buffers: BuffType = graph_meta.input_buffers
+
+        batch_size, num_blocks = block_offsets.size()
+        num_tokens = input_ids.size(-1)
+        # fill buffer
+        input_buffers['input_ids'][:, :num_tokens] = input_ids
+        input_buffers['position_ids'][:, :num_tokens] = position_ids
+        input_buffers[
+            'block_offsets'][:batch_size, :num_blocks] = block_offsets
+        # if q_seqlens.data_ptr() != input_buffers['q_seqlens'].data_ptr():
+        #     input_buffers['q_seqlens'].zero_()
+        input_buffers['q_seqlens'][:batch_size] = q_seqlens
+        # if kv_seqlens.data_ptr() != input_buffers['kv_seqlens'].data_ptr():
+        #     input_buffers['kv_seqlens'].zero_()
+        input_buffers['kv_seqlens'][:batch_size] = kv_seqlens
+        input_buffers['q_start_loc'][:batch_size] = q_start_loc
+
+
+        input_buffers['kv_start_indices'][:num_tokens] = kv_start_indices[:num_tokens]
+
+        if inputs_embeds is not None:
+            emb_size = inputs_embeds.size(-1)
+            if 'inputs_embeds' not in input_buffers:
+                max_num_tokens = input_buffers['input_ids'].size(-1)
+                input_buffers['inputs_embeds'] = inputs_embeds.new_zeros(
+                    1, max_num_tokens, emb_size)
+            input_buffers['inputs_embeds'][:, :num_tokens] = inputs_embeds
+
+        # create inputs
+        new_batch_size = round_up_to_multiple_of_8(batch_size)
+        new_num_tokens = round_up_to_multiple_of_8(num_tokens)
+
+        attn_metadata.block_offsets = input_buffers[
+            'block_offsets'][:new_batch_size]
+        attn_metadata.q_start_loc = input_buffers[
+            'q_start_loc'][:new_batch_size]
+        attn_metadata.q_seqlens = input_buffers['q_seqlens'][:new_batch_size]
+        attn_metadata.kv_seqlens = input_buffers['kv_seqlens'][:new_batch_size]
+
+        attn_metadata.kv_start_indices = input_buffers['kv_start_indices'][:new_num_tokens]
+        new_inputs = dict(
+            past_key_values=past_key_values,
+            attn_metadata=attn_metadata,
+        )
+
+        if is_decoding:
+            new_inputs['input_ids'] = input_buffers[
+                'input_ids'][:, :new_batch_size]
+            new_inputs['position_ids'] = input_buffers[
+                'position_ids'][:, :new_batch_size]
+        else:
+            new_inputs['input_ids'] = input_buffers['input_ids']
+            new_inputs['position_ids'] = input_buffers['position_ids']
+
+        if inputs_embeds is not None:
+            if is_decoding:
+                new_inputs['inputs_embeds'] = input_buffers[
+                    'inputs_embeds'][:, :new_batch_size]
+            else:
+                new_inputs['inputs_embeds'] = input_buffers['inputs_embeds']
+
+        new_inputs.update(kwargs)
+        return new_inputs
+
+    def update_camb_context(self, graph_meta, context):
+        """update step context with input buffers."""
+        input_buffers = graph_meta.input_buffers
+        local_adapter_ids = context.local_adapter_ids
+        if local_adapter_ids is not None:
+            if input_buffers['local_adapter_ids'].data_ptr(
+            ) != local_adapter_ids.data_ptr():
+                input_buffers['local_adapter_ids'].fill_(0)
+            batch_size = local_adapter_ids.size(0)
+            input_buffers['local_adapter_ids'][:batch_size] = local_adapter_ids
+            context.local_adapter_ids = input_buffers['local_adapter_ids']
+        context.q_seqlens = input_buffers['q_seqlens']
+        context.kv_seqlens = input_buffers['kv_seqlens']
+        context.q_start_loc = input_buffers['q_start_loc']
+        context.kv_start_indices = input_buffers['kv_start_indices']
+
+    def forward(self, **kwargs):
+        """forward."""
+        num_tokens = kwargs['input_ids'].size(-1)
+        assert self._graph is not None
+        self.update_camb_buffer(self.meta, **kwargs)
+        context = self.ctx_mgr.current_context()
+        self.update_camb_context(self.meta, context)
+
+        self._graph.replay()
+
+        output = self.meta.output_buffers['logits'][:, :num_tokens]
+        return output
+
+    def __del__(self):
+        """del."""
+        del self._graph
+
+
+class CAMBGraphRunner(GraphRunner):
+    """CAMB graph runner."""
+
+    def __init__(self, model: torch.nn.Module, model_config: ModelConfig,
+                 cache_config: CacheConfig, backend_config: BackendConfig,
+                 device: torch.device):
+        super().__init__(model, model_config, cache_config, backend_config,
+                         device)
+        self.max_batches = cache_config.max_batches
+        self.max_tokens = cache_config.max_prefill_token_num
+        self.num_blocks = cache_config.num_gpu_blocks
+
+        self.enable_graph = self.check_enable_graph()
+
+        self.graph_pool_handle = torch.mlu.graph_pool_handle()
+        self._runner_map: Dict[Any, CAMBSingleGraphRunner] = dict()
+
+    def check_enable_graph(self):
+        """check enable graph."""
+        if self.backend_config.eager_mode:
+            return _false
+
+        return getattr(self.model, 'support_cuda_graph', _false)
+
+    def get_graph_key(self, input_ids: torch.Tensor,
+                      position_ids: torch.Tensor, past_key_values: List,
+                      attn_metadata: Any, inputs_embeds: torch.Tensor,
+                      **kwargs):
+        """get graph key."""
+        context = self.ctx_mgr.current_context()
+        is_decoding = context.is_decoding
+        num_tokens = input_ids.numel()
+        new_num_tokens = round_up_to_multiple_of_8(num_tokens)
+        return (new_num_tokens, is_decoding)
+
+    def __call__(self, **kwargs):
+        """call."""
+        enable_graph = self.enable_graph(**kwargs)
+        graph_key = self.get_graph_key(**kwargs)
+        max_tokens = graph_key[0]
+        is_decoding = graph_key[1]
+
+        if (not enable_graph) or (not is_decoding):
+            return self.model(**kwargs)
+
+        if graph_key not in self._runner_map:
+            max_batches = max_tokens if is_decoding else self.max_batches
+            runner = CAMBSingleGraphRunner(self.model,
+                                           max_batches=max_batches,
+                                           max_tokens=max_tokens,
+                                           num_blocks=self.num_blocks,
+                                           is_decoding=is_decoding,
+                                           pool=self.graph_pool_handle,
+                                           device=self.device)
+            runner.capture(**kwargs)
+            self._runner_map[graph_key] = runner
+        else:
+            runner = self._runner_map[graph_key]
+
+        output = runner.forward(**kwargs)
+        return output
+
+    def prepare_inputs_for_generation(
+        self,
+        past_key_values: List[List[torch.Tensor]],
+        inputs_embeds: torch.Tensor = None,
+        context: StepContext = None,
+    ):
+        """prepare inputs."""
+        return self.model.prepare_inputs_for_generation(
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            context=context,
+        )
diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py
new file mode 100644
index 0000000000..41dec1dd21
--- /dev/null
+++ b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py
@@ -0,0 +1,118 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+
+from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
+from lmdeploy.utils import get_logger
+
+from ...base import OpType
+from ..op_backend import DlinferOpsBackend
+
+logger = get_logger('lmdeploy')
+
+
+class CambOpsBackend(DlinferOpsBackend):
+    """camb layer backend."""
+
+    @staticmethod
+    def get_name() -> str:
+        """backend name."""
+        return 'camb'
+
+    @staticmethod
+    def get_k_block_shape(
+        block_size: int,
+        num_heads: int,
+        head_size: int,
+        dtype: torch.dtype,
+    ) -> Tuple[int, ...]:
+        return (
+            num_heads,
+            block_size,
+            head_size,
+        )
+
+    @staticmethod
+    def get_v_block_shape(
+        block_size: int,
+        num_heads: int,
+        head_size: int,
+        dtype: torch.dtype,
+    ) -> Tuple[int, ...]:
+        return (
+            num_heads,
+            block_size,
+            head_size,
+        )
+
+    @classmethod
+    def update_step_context(cls, step_context):
+        """update step context."""
+        kv_start_indices, attention_mask = [], []
+        _, _, block_size, _ = step_context.kv_caches[0][0].shape
+        device = step_context.block_offsets.device
+        batch_size = step_context.q_start_loc.shape[0]
+
+        is_unpaged_prefill = False
+        q_start_loc = step_context.q_start_loc
+        q_seqlens = step_context.q_seqlens
+        kv_seqlens = step_context.kv_seqlens.to(torch.int32)
+        max_q_seq_len = torch.max(q_seqlens).cpu().item()
+        max_kv_seq_len = torch.max(kv_seqlens).cpu().item()
+
+        cu_seqlens = torch.zeros(batch_size+1, dtype=torch.int32, device=device)
+        cu_seqlens[:-1] = step_context.q_start_loc
+        cu_seqlens[-1] = step_context.q_seqlens.sum()
+
+        if not step_context.is_decoding:
+            is_unpaged_prefill = \
+                all((step_context.q_seqlens ==
+                     step_context.kv_seqlens).tolist())
+
+        for i in range(batch_size):
+            q_seq_len = int(step_context.q_seqlens[i])
+            kv_seq_len = int(step_context.kv_seqlens[i])
+            history_length = kv_seq_len - q_seq_len
+            block_idx = history_length // block_size
+            block_loc = step_context.block_offsets[i][block_idx]
+            token_loc = history_length % block_size
+            for j in range(q_seq_len):
+                kv_start_indices.append(block_loc * block_size + token_loc)
+                if j == q_seq_len - 1:
+                    break
+                token_loc = (token_loc + 1) % block_size
+                block_idx = block_idx if token_loc else block_idx + 1
+                block_loc = step_context.block_offsets[i][block_idx]
+        kv_start_indices = torch.tensor(kv_start_indices, device=device, dtype=torch.int32)
+
+        attn_meta_cls = cls.get_attention_metadata_cls()
+        attn_metadata = attn_meta_cls(
+            step_context.is_decoding,
+            step_context.block_offsets.to(torch.int32),
+            q_start_loc=q_start_loc,
+            q_seqlens=q_seqlens,
+            kv_seqlens=kv_seqlens,
+            kv_start_indices=kv_start_indices,
+            block_size=block_size,
+            attention_mask=None,
+            is_unpaged_prefill=is_unpaged_prefill,
+            max_q_seq_len=max_q_seq_len,
+            max_kv_seq_len=max_kv_seq_len,
+            cu_seqlens=cu_seqlens,
+            is_flash_attn_support_inplace=False,
+            is_mock_q_start_loc=True,
+        )
+
+        step_context.attn_metadata = attn_metadata
+        return step_context
+
+    @staticmethod
+    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig,
+                           cache_config: CacheConfig,
+                           backend_config: BackendConfig,
+                           device: torch.device):
+        """build graph runner."""
+        from .graph_runner import CAMBGraphRunner
+        return CAMBGraphRunner(model, model_config, cache_config,
+                               backend_config, device)
\ No newline at end of file
diff --git a/lmdeploy/pytorch/backends/selector.py b/lmdeploy/pytorch/backends/selector.py
index 987730a981..4db73fa370 100644
--- a/lmdeploy/pytorch/backends/selector.py
+++ b/lmdeploy/pytorch/backends/selector.py
@@ -18,5 +18,8 @@ def get_backend():
     if device_type == 'maca':
         from .dlinfer import MacaOpsBackend
         return MacaOpsBackend
+    if device_type == 'camb':
+        from .dlinfer import CambOpsBackend
+        return CambOpsBackend
     else:
         raise RuntimeError(f'Unsupported device type: {device_type}')
diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py
index 7d72438224..2d3ad73439 100644
--- a/lmdeploy/pytorch/check_env/__init__.py
+++ b/lmdeploy/pytorch/check_env/__init__.py
@@ -33,6 +33,7 @@ def try_import_deeplink(device_type: str):
         'ascend',
         'npu',
         'maca',
+        'camb',
     ]
     if device_type in deeplink_device_type_list:
         logger = get_logger('lmdeploy')
diff --git a/lmdeploy/pytorch/kernels/dlinfer/__init__.py b/lmdeploy/pytorch/kernels/dlinfer/__init__.py
index 8f86f0019a..0014a88e4d 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/__init__.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/__init__.py
@@ -8,6 +8,7 @@
 from .moe_gating_topk_softmax import moe_gating_topk_softmax
 from .pagedattention import paged_attention_fwd
 from .rms_norm import rms_norm
+from .activation import silu_and_mul
 
 __all__ = [
     'rms_norm',
@@ -19,4 +20,5 @@
     'linear',
     'moe_gating_topk_softmax',
     'multinomial_sampling',
+    'silu_and_mul',
 ]
diff --git a/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py
index 0f13f3f38c..be1698289f 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py
@@ -15,24 +15,16 @@ def apply_rotary_pos_emb(
 ) -> Tuple[Tensor, Tensor]:
     query_states = query_states.contiguous()
     key_states = key_states.contiguous()
-    bs = query_states.shape[0]
-    query_states_reshaped = query_states.unsqueeze(0)
-    key_states_reshaped = key_states.unsqueeze(0)
-    cos_reshaped = cos.reshape(1, bs, 1, -1)
-    sin_reshaped = sin.reshape(1, bs, 1, -1)
-    query_states_reshaped, key_states_reshaped = \
-        ext_ops.apply_rotary_pos_emb(query_states_reshaped,
-                                     key_states_reshaped,
-                                     cos_reshaped, sin_reshaped,
-                                     None, None)
+    query_states, key_states = ext_ops.apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, None, None)
+
     if q_embed is None:
-        q_embed = query_states_reshaped.view(query_states.shape)
+        q_embed = query_states
     elif q_embed is not query_states:
-        q_embed.copy_(query_states_reshaped.view(query_states.shape))
+        q_embed.copy_(query_states)
     if k_embed is None:
-        k_embed = key_states_reshaped.view(key_states.shape)
+        k_embed = key_states
     elif k_embed is not key_states:
-        k_embed.copy_(key_states_reshaped.view(key_states.shape))
+        k_embed.copy_(key_states)
 
     return q_embed, k_embed
diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
index 21c72074a4..0fd50ce932 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
@@ -20,9 +20,6 @@ def prefill_attention(
     attn_mask: Sequence[Optional[Tensor]],
     is_unpaged_prefill: Optional[bool],
 ) -> Tensor:
-    num_q_heads = query_states.shape[1]
-    num_kv_heads = value_states.shape[1]
-
     if is_unpaged_prefill:
         return ext_ops.prefill_attention(
             query_states,
@@ -31,8 +28,6 @@
             q_start_loc,
             q_seq_len,
             max_q_seq_len,
-            num_q_heads,
-            num_kv_heads,
             attn_mask,
             attn_output=attn_output,
         )
@@ -55,8 +50,6 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
                           max_kv_seq_len, block_offsets, block_size):
-    num_q_heads, q_head_dim = q.shape[1:3]
-    num_kv_heads = k_cache.shape[-1] // q_head_dim
     return ext_ops.paged_decode_attention(
         q,
         k_cache,
@@ -65,8 +58,6 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
         block_size,
         kv_seq_len,
         max_kv_seq_len,
-        num_q_heads,
-        num_kv_heads,
         attn_output=attn_output,
     )
diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py
index 497090afc5..b9434fcd05 100644
--- a/lmdeploy/pytorch/models/internlm2.py
+++ b/lmdeploy/pytorch/models/internlm2.py
@@ -439,3 +439,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             else:
                 param = params_dict[name]
                 load_weight(param, loaded_weight)
+
diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py
index fbdd374f80..c065e5ff76 100644
--- a/lmdeploy/utils.py
+++ b/lmdeploy/utils.py
@@ -332,7 +332,7 @@ def get_max_batch_size(device_type: str):
     Args:
        device_type (str): the type of device
    """
-    assert device_type in ['cuda', 'ascend', 'maca']
+    assert device_type in ['cuda', 'ascend', 'maca', 'camb']
     if device_type == 'cuda':
         max_batch_size_map = {
             'a100': 256,
@@ -352,6 +352,8 @@
         return 16
     elif device_type == 'maca':
         return 128
+    elif device_type == 'camb':
+        return 16
 
 
 def is_bf16_supported(device_type: str = 'cuda'):
@@ -387,5 +389,7 @@ def is_bf16_supported(device_type: str = 'cuda'):
     #     return False
     elif device_type == 'maca':
         return True
+    elif device_type == 'camb':
+        return True
     else:
         return False
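
For reviewers trying the patch out, here is a minimal usage sketch (not part of the diff). The 'camb' device_type is accepted by the updated PytorchEngineConfig.__post_init__ check and routed to CambOpsBackend via selector.py; the model path, prompt, and the eager_mode flag below are illustrative assumptions, not values taken from this PR.

# Hypothetical usage sketch for the 'camb' device_type added in this diff.
# Model path and prompt are placeholders.
from lmdeploy import pipeline, PytorchEngineConfig

if __name__ == '__main__':
    backend_config = PytorchEngineConfig(
        device_type='camb',  # accepted by the updated device_type assertion
        eager_mode=False,    # False exercises the CAMBGraphRunner MLU graph capture
    )
    pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config)
    print(pipe(['Hi, please introduce yourself.']))

On the command line, the equivalent selection should be available through the --device choice added in lmdeploy/cli/utils.py, e.g. `lmdeploy chat <model_path> --backend pytorch --device camb`.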