4 changes: 2 additions & 2 deletions .github/workflows/main.yml
@@ -12,7 +12,7 @@ on:
env:
CI_PATH: "${HOME}/GitHub/${{ github.repository }}/${GITHUB_RUN_NUMBER}"
LMDEPLOY_PATH: "${HOME}/GitHub/lmdeploy"
LMDEPLOY_COMMIT_OR_BRANCH: 'main'
LMDEPLOY_COMMIT_OR_BRANCH: 'refactor_code'
REPORT_DIR: "${HOME}/GitHub/ci_log/test_reports"
TEST_LMDEPLOY_E2E_LOG_PATH: "${HOME}/Github/ci_log/logs"
TEST_LMDEPLOY_E2E_MODEL_PATH: "${HOME}/Github/model"
@@ -74,7 +74,7 @@ jobs:
- name: Clone lmdeploy
run: |
set -ex
git clone https://github.com/InternLM/lmdeploy.git ${{ env.LMDEPLOY_PATH }}
git clone https://github.com/DeepLink-org/lmdeploy.git ${{ env.LMDEPLOY_PATH }}
cd ${{ env.LMDEPLOY_PATH }} && git checkout ${{ env.LMDEPLOY_COMMIT_OR_BRANCH }}
# git apply ${{env.CI_PATH }}/.github/ci/fix-exit-multi-npu.patch

@@ -14,7 +14,6 @@
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager
from lmdeploy.pytorch.backends.graph_runner import GraphRunner
from lmdeploy.pytorch.backends.cuda import graph_runner

from lmdeploy.utils import get_logger

@@ -55,7 +54,7 @@ def AscendCudaGraphMixin_make_buffers_cudagraph(
)

input_buffers["kv_start_indices"] = -torch.ones(
(max_batches), dtype=torch.int64, device=device
(max_batches), dtype=torch.int32, device=device
)
return input_buffers
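Note on the hunk above: it changes the dtype of the pre-filled `kv_start_indices` capture buffer (int64 vs. int32). For readers unfamiliar with the pattern, here is a minimal sketch of how such a persistent, sentinel-filled index buffer is allocated for graph capture; the function name and shapes are assumptions for illustration, not part of the diff:

```python
import torch

def make_kv_start_indices_buffer(max_batches: int, device, dtype=torch.int32):
    # One slot index per possible batch entry; -1 marks unused entries so
    # stale slots are easy to recognize when the captured graph is replayed.
    return -torch.ones((max_batches,), dtype=dtype, device=device)

buf = make_kv_start_indices_buffer(8, device="cpu")
assert buf.dtype == torch.int32 and bool((buf == -1).all())
```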

@@ -367,6 +366,7 @@ def __call__(self, **kwargs):
)
AscendGraphRunner.capturing = True
runner.capture(**kwargs)
AscendGraphRunner.capturing = False
self._runner_map[graph_key] = runner
else:
runner = self._runner_map[graph_key]
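The single added line above clears the class-level `capturing` flag once `runner.capture(**kwargs)` returns, mirroring the `capturing = True` set just before it. A self-contained sketch of the pattern; the dummy runner and the try/finally are assumptions added here so the flag is cleared even if capture raises:

```python
class _DummyRunner:
    def capture(self, **kwargs):
        pass  # stand-in for the real graph capture


class AscendGraphRunnerSketch:
    capturing = False  # class-level flag other code can consult during capture

    @classmethod
    def capture_with_flag(cls, runner, **kwargs):
        cls.capturing = True
        try:
            runner.capture(**kwargs)
        finally:
            cls.capturing = False


AscendGraphRunnerSketch.capture_with_flag(_DummyRunner())
assert AscendGraphRunnerSketch.capturing is False
```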
@@ -424,9 +424,6 @@ def get_capture_batch_sizes(self) -> List[int]:
return _get_capture_batch_size_impl(self.cache_config.max_batches)


graph_runner.CUDAGraphRunner = AscendGraphRunner
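The assignment removed here (together with the `from lmdeploy.pytorch.backends.cuda import graph_runner` import dropped earlier in this file) was an import-time monkeypatch: rebinding `CUDAGraphRunner` on lmdeploy's cuda `graph_runner` module so existing call sites pick up the Ascend runner. A generic, self-contained illustration of that rebinding pattern, using dummy names rather than the real lmdeploy module:

```python
import types

# Stand-in for the module whose attribute gets rebound.
backend = types.ModuleType("backend_graph_runner")


class CUDAGraphRunnerStub:  # original implementation
    pass


class AscendGraphRunnerStub(CUDAGraphRunnerStub):  # vendor-specific replacement
    pass


backend.CUDAGraphRunner = CUDAGraphRunnerStub
backend.CUDAGraphRunner = AscendGraphRunnerStub  # the kind of rebinding being removed
assert backend.CUDAGraphRunner is AscendGraphRunnerStub
```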


@dataclass
class GraphParams:
events: dict[int, list[torch.npu.ExternalEvent]]
67 changes: 18 additions & 49 deletions dlinfer/vendor/ascend/torch_npu_ops.py
@@ -95,15 +95,17 @@ def prefill_attention(
value = value.contiguous()
scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(query.shape[-1])
if len(attn_mask):
mask = attn_mask[0].to(query.dtype)
mask = attn_mask[0]
else:
mask = torch.logical_not(
torch.tril(
torch.ones(
max_q_seq_len, max_q_seq_len, dtype=torch.bool, device=query.device
)
)
).to(query.dtype)
mask = torch.triu(
torch.ones(
max_q_seq_len,
max_q_seq_len,
dtype=query.dtype,
device=query.device,
),
diagonal=1,
)
q_seq_len = q_seq_len.cpu()
if SocVersion.is_Ascend910():
torch.ops.atb._npu_flash_attention(
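The two causal-mask constructions in this hunk are interchangeable: both mark positions strictly above the diagonal (future tokens) with 1 in the query dtype and leave the rest 0. A quick equivalence check, with the size and dtype chosen here purely for illustration:

```python
import torch

n, dtype = 5, torch.float16
mask_via_tril_not = torch.logical_not(
    torch.tril(torch.ones(n, n, dtype=torch.bool))
).to(dtype)
mask_via_triu = torch.triu(torch.ones(n, n, dtype=dtype), diagonal=1)
assert torch.equal(mask_via_tril_not, mask_via_triu)
```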
@@ -231,7 +233,7 @@ def quant_int8(x, x_scale, x_offset):
value=value,
key_cache=key_cache,
value_cache=value_cache,
slot_indices=kv_indices.to(torch.int32),
slot_indices=kv_indices,
)
return key_cache, value_cache
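For context on the cache-fill call above (presumably dlinfer's `fill_kv_cache`): `kv_indices`/`slot_indices` are flat slot ids into the paged cache, typically `block_id * block_size + offset`. A framework-free reference sketch of the semantics; the helper name and shapes are assumptions, and this is not the Ascend kernel itself:

```python
import torch

def fill_kv_cache_reference(key, value, key_cache, value_cache, kv_indices):
    # key/value: (num_tokens, num_kv_heads, head_dim)
    # key_cache/value_cache: (num_blocks, block_size, num_kv_heads, head_dim)
    num_blocks, block_size = key_cache.shape[:2]
    flat_k = key_cache.view(num_blocks * block_size, *key_cache.shape[2:])
    flat_v = value_cache.view(num_blocks * block_size, *value_cache.shape[2:])
    idx = kv_indices.long().flatten()
    flat_k[idx] = key  # writes through the view into key_cache
    flat_v[idx] = value
    return key_cache, value_cache
```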

@@ -272,8 +274,6 @@ def paged_decode_attention(
raise RuntimeError(
"paged_decode_attention does not " "support alibi_slopes yet"
)
if isinstance(block_table, torch.Tensor) and block_table.dtype != torch.int32:
block_table = block_table.to(torch.int32)

query = query.contiguous()
attn_output = attn_output.contiguous()
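Both this hunk and the `slot_indices` change above drop a per-call `.to(torch.int32)` conversion, presumably because callers are now expected to pass int32 index tensors already. A minimal sketch of normalizing the dtype once where the block table is built; the helper is hypothetical and not part of this PR:

```python
import torch

def build_block_table(block_ids_per_seq, max_blocks_per_seq, device="cpu"):
    # Build the (num_seqs, max_blocks_per_seq) table directly in int32 so the
    # attention ops no longer need to cast it on every call.
    table = torch.zeros(
        len(block_ids_per_seq), max_blocks_per_seq, dtype=torch.int32, device=device
    )
    for i, ids in enumerate(block_ids_per_seq):
        table[i, : len(ids)] = torch.as_tensor(ids, dtype=torch.int32, device=device)
    return table

table = build_block_table([[0, 3], [5]], max_blocks_per_seq=4)
assert table.dtype == torch.int32
```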
@@ -358,57 +358,26 @@ def paged_prefill_attention(
"paged_decode_attention does not " "support alibi_slopes yet"
)

if block_table.dtype != torch.int32:
block_table = block_table.to(torch.int32)

kv_seq_len_list = kv_seq_len.tolist()
scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(query.shape[-1])
query = query.contiguous().view(query.shape[0], 1, -1)
block_num = key_cache.size(0)
key_cache = key_cache.view(block_num, block_size, -1)
value_cache = value_cache.view(block_num, block_size, -1)

attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
query,
key_cache,
value_cache,
pse_shift=None,
query=query,
key=key_cache,
value=value_cache,
atten_mask=attn_mask[0],
actual_seq_lengths=kv_seq_len_list,
actual_seq_lengths_kv=kv_seq_len,
dequant_scale1=None,
quant_scale1=None,
dequant_scale2=None,
quant_scale2=None,
quant_offset2=None,
antiquant_scale=kv_scales,
antiquant_offset=kv_zeros,
block_table=block_table,
query_padding_size=None,
kv_padding_size=None,
key_antiquant_scale=None,
key_antiquant_offset=None,
value_antiquant_scale=None,
value_antiquant_offset=None,
key_shared_prefix=None,
value_shared_prefix=None,
actual_shared_prefix_len=None,
query_rope=None,
key_rope=None,
key_rope_antiquant_scale=None,
num_heads=num_q_heads,
scale=scale_value,
pre_tokens=2147483647,
next_tokens=2147483647,
input_layout="BSH",
block_size=block_size,
actual_seq_lengths=q_seq_len,
actual_seq_lengths_kv=kv_seq_len,
num_key_value_heads=num_kv_heads,
num_heads=num_q_heads,
scale=scale_value,
sparse_mode=0,
inner_precise=1,
block_size=block_size,
antiquant_mode=0,
softmax_lse_flag=False,
key_antiquant_mode=0,
value_antiquant_mode=0,
)

return attn_output
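A note on the reshapes feeding the call above: `input_layout="BSH"` means (batch, sequence, hidden), so the caches are flattened to hidden = num_kv_heads * head_dim and the query is viewed as one token per sequence position. Illustration with assumed shapes:

```python
import torch

num_blocks, block_size, num_kv_heads, head_dim = 8, 128, 4, 64
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim)

# Cache in BSH layout: hidden = num_kv_heads * head_dim.
key_cache_bsh = key_cache.view(num_blocks, block_size, -1)
assert key_cache_bsh.shape == (num_blocks, block_size, num_kv_heads * head_dim)

# The query gets the same treatment, one token per "sequence" position.
num_tokens, num_q_heads = 16, 32
query = torch.zeros(num_tokens, num_q_heads, head_dim)
query_bsh = query.contiguous().view(query.shape[0], 1, -1)
assert query_bsh.shape == (num_tokens, 1, num_q_heads * head_dim)
```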