diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 18d2d489..f67bf6f8 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -12,7 +12,7 @@ on:
 env:
   CI_PATH: "${HOME}/GitHub/${{ github.repository }}/${GITHUB_RUN_NUMBER}"
   LMDEPLOY_PATH: "${HOME}/GitHub/lmdeploy"
-  LMDEPLOY_COMMIT_OR_BRANCH: 'main'
+  LMDEPLOY_COMMIT_OR_BRANCH: 'refactor_code'
   REPORT_DIR: "${HOME}/GitHub/ci_log/test_reports"
   TEST_LMDEPLOY_E2E_LOG_PATH: "${HOME}/Github/ci_log/logs"
   TEST_LMDEPLOY_E2E_MODEL_PATH: "${HOME}/Github/model"
@@ -74,7 +74,7 @@ jobs:
       - name: Clone lmdeploy
         run: |
           set -ex
-          git clone https://github.com/InternLM/lmdeploy.git ${{ env.LMDEPLOY_PATH }}
+          git clone https://github.com/DeepLink-org/lmdeploy.git ${{ env.LMDEPLOY_PATH }}
           cd ${{ env.LMDEPLOY_PATH }} && git checkout ${{ env.LMDEPLOY_COMMIT_OR_BRANCH }}
           # git apply ${{env.CI_PATH }}/.github/ci/fix-exit-multi-npu.patch
 
diff --git a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py
index 39fc830d..2c25810f 100644
--- a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py
+++ b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py
@@ -14,7 +14,6 @@
 from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
 from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager
 from lmdeploy.pytorch.backends.graph_runner import GraphRunner
-from lmdeploy.pytorch.backends.cuda import graph_runner
 from lmdeploy.utils import get_logger
 
 
@@ -55,7 +54,7 @@ def AscendCudaGraphMixin_make_buffers_cudagraph(
     )
 
     input_buffers["kv_start_indices"] = -torch.ones(
-        (max_batches), dtype=torch.int64, device=device
+        (max_batches), dtype=torch.int32, device=device
     )
 
     return input_buffers
@@ -367,6 +366,7 @@ def __call__(self, **kwargs):
             )
             AscendGraphRunner.capturing = True
             runner.capture(**kwargs)
+            AscendGraphRunner.capturing = False
             self._runner_map[graph_key] = runner
         else:
             runner = self._runner_map[graph_key]
@@ -424,9 +424,6 @@ def get_capture_batch_sizes(self) -> List[int]:
         return _get_capture_batch_size_impl(self.cache_config.max_batches)
 
 
-graph_runner.CUDAGraphRunner = AscendGraphRunner
-
-
 @dataclass
 class GraphParams:
     events: dict[int, list[torch.npu.ExternalEvent]]
diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py
index 98fa4402..e73104b3 100644
--- a/dlinfer/vendor/ascend/torch_npu_ops.py
+++ b/dlinfer/vendor/ascend/torch_npu_ops.py
@@ -95,15 +95,17 @@ def prefill_attention(
     value = value.contiguous()
     scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(query.shape[-1])
     if len(attn_mask):
-        mask = attn_mask[0].to(query.dtype)
+        mask = attn_mask[0]
     else:
-        mask = torch.logical_not(
-            torch.tril(
-                torch.ones(
-                    max_q_seq_len, max_q_seq_len, dtype=torch.bool, device=query.device
-                )
-            )
-        ).to(query.dtype)
+        mask = torch.triu(
+            torch.ones(
+                max_q_seq_len,
+                max_q_seq_len,
+                dtype=query.dtype,
+                device=query.device,
+            ),
+            diagonal=1,
+        )
     q_seq_len = q_seq_len.cpu()
     if SocVersion.is_Ascend910():
         torch.ops.atb._npu_flash_attention(
@@ -231,7 +233,7 @@ def quant_int8(x, x_scale, x_offset):
         value=value,
         key_cache=key_cache,
         value_cache=value_cache,
-        slot_indices=kv_indices.to(torch.int32),
+        slot_indices=kv_indices,
     )
     return key_cache, value_cache
 
@@ -272,8 +274,6 @@ def paged_decode_attention(
         raise RuntimeError(
            "paged_decode_attention does not " "support alibi_slopes yet"
         )
-    if isinstance(block_table, torch.Tensor) and block_table.dtype != torch.int32:
-        block_table = block_table.to(torch.int32)
     query = query.contiguous()
     attn_output = attn_output.contiguous()
 
@@ -358,10 +358,6 @@ def paged_prefill_attention(
            "paged_decode_attention does not " "support alibi_slopes yet"
         )
 
-    if block_table.dtype != torch.int32:
-        block_table = block_table.to(torch.int32)
-
-    kv_seq_len_list = kv_seq_len.tolist()
     scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(query.shape[-1])
     query = query.contiguous().view(query.shape[0], 1, -1)
     block_num = key_cache.size(0)
@@ -369,46 +365,19 @@ def paged_prefill_attention(
 
     value_cache = value_cache.view(block_num, block_size, -1)
     attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
-        query,
-        key_cache,
-        value_cache,
-        pse_shift=None,
+        query=query,
+        key=key_cache,
+        value=value_cache,
         atten_mask=attn_mask[0],
-        actual_seq_lengths=kv_seq_len_list,
-        actual_seq_lengths_kv=kv_seq_len,
-        dequant_scale1=None,
-        quant_scale1=None,
-        dequant_scale2=None,
-        quant_scale2=None,
-        quant_offset2=None,
-        antiquant_scale=kv_scales,
-        antiquant_offset=kv_zeros,
         block_table=block_table,
-        query_padding_size=None,
-        kv_padding_size=None,
-        key_antiquant_scale=None,
-        key_antiquant_offset=None,
-        value_antiquant_scale=None,
-        value_antiquant_offset=None,
-        key_shared_prefix=None,
-        value_shared_prefix=None,
-        actual_shared_prefix_len=None,
-        query_rope=None,
-        key_rope=None,
-        key_rope_antiquant_scale=None,
-        num_heads=num_q_heads,
-        scale=scale_value,
-        pre_tokens=2147483647,
-        next_tokens=2147483647,
         input_layout="BSH",
+        block_size=block_size,
+        actual_seq_lengths=q_seq_len,
+        actual_seq_lengths_kv=kv_seq_len,
         num_key_value_heads=num_kv_heads,
+        num_heads=num_q_heads,
+        scale=scale_value,
         sparse_mode=0,
-        inner_precise=1,
-        block_size=block_size,
-        antiquant_mode=0,
-        softmax_lse_flag=False,
-        key_antiquant_mode=0,
-        value_antiquant_mode=0,
     )
     return attn_output
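Reviewer note (not part of the patch): a minimal standalone sketch checking that the new causal-mask construction in prefill_attention matches the one it replaces; `n` and `dtype` below are placeholder stand-ins for max_q_seq_len and query.dtype, and the check runs on CPU.

    import torch

    n = 8                  # stands in for max_q_seq_len
    dtype = torch.float16  # stands in for query.dtype

    # Old construction: invert a lower-triangular boolean mask, then cast.
    old_mask = torch.logical_not(
        torch.tril(torch.ones(n, n, dtype=torch.bool))
    ).to(dtype)

    # New construction: take the strictly upper-triangular part directly in the
    # target dtype (note that diagonal=1 is an argument of torch.triu, not torch.ones).
    new_mask = torch.triu(torch.ones(n, n, dtype=dtype), diagonal=1)

    assert torch.equal(old_mask, new_mask)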