diff --git a/dlinfer/ops/llm.py b/dlinfer/ops/llm.py
index f5c056ca..63d2acf1 100644
--- a/dlinfer/ops/llm.py
+++ b/dlinfer/ops/llm.py
@@ -110,8 +110,8 @@ def prefill_attention(
     value: Tensor,
     q_start_loc: Tensor,
     q_seq_len: Tensor,
-    max_q_seq_len: int,
-    num_q_heads: int,
+    max_q_seq_len: Tensor,
+    num_q_heads: Tensor,
     num_kv_heads: int,
     attn_mask: Sequence[Optional[Tensor]],
     softmax_scale: Optional[float],
@@ -128,7 +128,7 @@ def prefill_attention(
         value (Tensor): The value tensor.
         q_start_loc (Tensor): The start location of each query sequence.
         q_seq_len (Tensor): The length of each query sequence.
-        max_q_seq_len (int): The maximum length of any query sequence.
+        max_q_seq_len (Tensor): The maximum length of any query sequence.
         num_q_heads (int): The number of query heads.
         num_kv_heads (int): The number of key/value heads.
         attn_mask (Sequence[Optional[Tensor]]): A sequence of optional attention masks, one for each batch.
@@ -223,7 +223,7 @@ def paged_decode_attention(
     block_table: Tensor,
     block_size: int,
     kv_seq_len: Tensor,
-    max_kv_seq_len: int,
+    max_kv_seq_len: Tensor,
     num_q_heads: int,
     num_kv_heads: int,
     softmax_scale: Optional[float],
@@ -245,7 +245,7 @@
            block in the key/value cache.
         block_size (int): The size of each block in the input sequence.
         kv_seq_len (Tensor): The length of each key/value sequence.
-        max_kv_seq_len (int): The maximum length of any key/value sequence.
+        max_kv_seq_len (Tensor): The maximum length of any key/value sequence.
         num_q_heads (int): The number of query heads.
         num_kv_heads (int): The number of key/value heads.
         softmax_scale (Optional[float]): The scale factor to apply to the attention logits before the softmax.
@@ -301,8 +301,8 @@ def paged_prefill_attention(
     q_seq_len: Tensor,
     kv_seq_len: Tensor,
     cu_seq_lens_kv: Tensor,
-    max_q_seq_len: int,
-    max_kv_seq_len: int,
+    max_q_seq_len: Tensor,
+    max_kv_seq_len: Tensor,
     num_q_heads: int,
     num_kv_heads: int,
     attn_mask: Sequence[Optional[Tensor]],
@@ -328,8 +328,8 @@
         q_seq_len (Tensor): The length of each query sequence.
         kv_seq_len (Tensor): The length of each key/value sequence.
         cu_seq_lens_kv (Tensor): The cumulative sequence lengths of the key/value sequences.
-        max_q_seq_len (int): The maximum length of any query sequence.
-        max_kv_seq_len (int): The maximum length of any key/value sequence.
+        max_q_seq_len (Tensor): The maximum length of any query sequence.
+        max_kv_seq_len (Tensor): The maximum length of any key/value sequence.
         num_q_heads (int): The number of query heads.
         num_kv_heads (int): The number of key/value heads.
         attn_mask (Sequence[Optional[Tensor]]): A sequence of optional attention masks, one for each batch.
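The change above switches the maximum-sequence-length arguments in dlinfer/ops/llm.py (and num_q_heads in prefill_attention) from Python ints to Tensors, so the value can be produced and kept on the device instead of being read back to the host. Below is a minimal sketch of how a caller forms the new-style argument; the device placement and the host-sync rationale are illustrative assumptions, not something the diff states.

import torch

# Per-sequence query lengths for a packed prefill batch (illustrative values).
q_seq_len = torch.tensor([6, 10], device="cuda", dtype=torch.int32)

# Old interface: a Python int, which forces a device-to-host read each step.
#   max_q_seq_len = q_seq_len.max().item()

# New interface: a 0-dim tensor that stays on the device.
max_q_seq_len = q_seq_len.max()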
diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py
index 9e359f62..13f7e1eb 100644
--- a/dlinfer/vendor/ascend/torch_npu_ops.py
+++ b/dlinfer/vendor/ascend/torch_npu_ops.py
@@ -72,7 +72,7 @@ def prefill_attention(
     value: Tensor,
     q_start_loc: Tensor,
     q_seq_len: Tensor,
-    max_q_seq_len: int,
+    max_q_seq_len: Tensor,
     num_q_heads: int,
     num_kv_heads: int,
     attn_mask: Sequence[Optional[Tensor]],
@@ -174,7 +174,7 @@ def paged_decode_attention(
     block_table: Optional[Tensor],
     block_size: int,
     kv_seq_len: Tensor,
-    max_kv_seq_len: int,
+    max_kv_seq_len: Tensor,
     num_q_heads: int,
     num_kv_heads: int,
     softmax_scale: Optional[float],
@@ -236,8 +236,8 @@ def paged_prefill_attention(
     q_seq_len: Tensor,
     kv_seq_len: Tensor,
     cu_seq_lens_kv: Tensor,
-    max_q_seq_len: int,
-    max_kv_seq_len: int,
+    max_q_seq_len: Tensor,
+    max_kv_seq_len: Tensor,
     num_q_heads: int,
     num_kv_heads: int,
     attn_mask: Sequence[Optional[Tensor]],
diff --git a/dlinfer/vendor/camb/camb_ops.py b/dlinfer/vendor/camb/camb_ops.py
index 91e4604b..2db1baeb 100644
--- a/dlinfer/vendor/camb/camb_ops.py
+++ b/dlinfer/vendor/camb/camb_ops.py
@@ -145,7 +145,7 @@ def prefill_attention(
     value: Tensor,
     q_start_loc: Tensor,
     q_seq_len: Tensor,
-    max_q_seq_len: int,
+    max_q_seq_len: Tensor,
     num_q_heads: int,
     num_kv_heads: int,
     attn_mask: Sequence[Optional[Tensor]],
@@ -195,7 +195,7 @@ def paged_decode_attention(
     block_table: Optional[Tensor],
     block_size: int,
     kv_seq_len: Tensor,
-    max_kv_seq_len: int,
+    max_kv_seq_len: Tensor,
     num_q_heads: int,
     num_kv_heads: int,
     softmax_scale: Optional[float],
@@ -260,8 +260,8 @@ def paged_prefill_attention(
     q_seq_len: Tensor,
     kv_seq_len: Tensor,
     cu_seq_lens_kv: Tensor,
-    max_q_seq_len: int,
-    max_kv_seq_len: int,
+    max_q_seq_len: Tensor,
+    max_kv_seq_len: Tensor,
     num_q_heads: int,
     num_kv_heads: int,
     attn_mask: Sequence[Optional[Tensor]],
diff --git a/dlinfer/vendor/maca/csrc/CMakeLists.txt b/dlinfer/vendor/maca/csrc/CMakeLists.txt
index 0f9d9f1d..26a601d5 100644
--- a/dlinfer/vendor/maca/csrc/CMakeLists.txt
+++ b/dlinfer/vendor/maca/csrc/CMakeLists.txt
@@ -17,6 +17,7 @@ find_library(torch_python_LIBRARY torch_python PATHS
 
 set(DLINFER_VLLM_SRC
     "pybind.cpp"
+    "attention/attention_kernels.cu"
     "pos_encoding_kernels.cu"
     "moe_align_block_size_kernels.cu"
     "moe/topk_softmax_kernels.cu"
diff --git a/dlinfer/vendor/maca/csrc/attention/attention_kernels.cu b/dlinfer/vendor/maca/csrc/attention/attention_kernels.cu
index e44f1274..ea87810d 100644
--- a/dlinfer/vendor/maca/csrc/attention/attention_kernels.cu
+++ b/dlinfer/vendor/maca/csrc/attention/attention_kernels.cu
@@ -1463,7 +1463,7 @@ template <typename T, typename CACHE_T, int BLOCK_SIZE,
 void paged_attention_v1_launcher(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
-    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, torch::Tensor& max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
     float v_scale, const int tp_rank, const int blocksparse_local_blocks,
     const int blocksparse_vert_stride, const int blocksparse_block_size,
@@ -1494,8 +1494,9 @@
 
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   int padded_max_seq_len =
-      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
+      DIVIDE_ROUND_UP(512, BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_seq_len * sizeof(float);
+  int V_VEC_SIZE = 16 / sizeof(CACHE_T);
   int NUM_V_VECS_PER_THREAD = head_size / V_VEC_SIZE;
   int NUM_COLS_PER_ITER = MAX(WARP_SIZE / NUM_V_VECS_PER_THREAD, 1);
 
@@ -1590,7 +1591,7 @@ void paged_attention_v1(
     double scale,
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& seq_lens,      // [num_seqs]
-    int64_t block_size, int64_t max_seq_len,
+    int64_t block_size, torch::Tensor& max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, double k_scale, double v_scale,
     const int64_t tp_rank, const int64_t blocksparse_local_blocks,
diff --git a/dlinfer/vendor/maca/csrc/ops.h b/dlinfer/vendor/maca/csrc/ops.h
index 5b455973..74ffa67b 100644
--- a/dlinfer/vendor/maca/csrc/ops.h
+++ b/dlinfer/vendor/maca/csrc/ops.h
@@ -6,7 +6,7 @@
 #include <torch/extension.h>
 
 void paged_attention_v1(torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
-                        torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len,
+                        torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& max_seq_len,
                         const c10::optional<torch::Tensor>& alibi_slopes, const std::string& kv_cache_dtype, double k_scale, double v_scale,
                         const int64_t tp_rank, const int64_t blocksparse_local_blocks,
                         const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step);
diff --git a/dlinfer/vendor/maca/csrc/pybind.cpp b/dlinfer/vendor/maca/csrc/pybind.cpp
index 90c44bcd..52231dbe 100644
--- a/dlinfer/vendor/maca/csrc/pybind.cpp
+++ b/dlinfer/vendor/maca/csrc/pybind.cpp
@@ -18,6 +18,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     // vLLM custom ops
     pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
 
+    // Attention ops
+    // Compute the attention between an input query and the cached
+    // keys/values using PagedAttention.
+    ops.def("paged_attention_v1",
+            &paged_attention_v1,
+            "paged_attention_v1("
+            " Tensor! out, Tensor query, Tensor key_cache,"
+            " Tensor value_cache, int num_kv_heads, float scale,"
+            " Tensor block_tables, Tensor seq_lens, int block_size,"
+            " Tensor max_seq_len, Tensor? alibi_slopes,"
+            " str kv_cache_dtype, float k_scale, float v_scale,"
+            " int tp_rank, int blocksparse_local_blocks,"
+            " int blocksparse_vert_stride, int blocksparse_block_size,"
+            " int blocksparse_head_sliding_step) -> ()");
+
     // Rotary embedding
     // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
     ops.def("rotary_embedding",
diff --git a/dlinfer/vendor/maca/maca_ops.py b/dlinfer/vendor/maca/maca_ops.py
index d1392b0b..a8ec5366 100644
--- a/dlinfer/vendor/maca/maca_ops.py
+++ b/dlinfer/vendor/maca/maca_ops.py
@@ -102,7 +102,7 @@ def prefill_attention(
     value: Tensor,
     q_start_loc: Tensor,
     q_seq_len: Tensor,
-    max_q_seq_len: int,
+    max_q_seq_len: Tensor,
     num_q_heads: int,
     num_kv_heads: int,
     attn_mask: Sequence[Optional[Tensor]],
@@ -207,7 +207,7 @@ def paged_decode_attention(
     block_table: Optional[Tensor],
     block_size: int,
     kv_seq_len: Tensor,
-    max_kv_seq_len: int,
+    max_kv_seq_len: Tensor,
     num_q_heads: int,
     num_kv_heads: int,
     softmax_scale: Optional[float],
@@ -225,7 +225,7 @@ def paged_decode_attention(
     num_kv_heads = value_cache.size(1)
     block_size = value_cache.size(2)
     output = torch.empty_like(query)
-    custom_ops.paged_attention_v1(
+    maca_ext_ops.paged_attention_v1(
         output,
         query,
         key_cache,
@@ -263,8 +263,8 @@ def paged_prefill_attention(
     q_seq_len: Tensor,
     kv_seq_len: Tensor,
     cu_seq_lens_kv: Tensor,
-    max_q_seq_len: int,
-    max_kv_seq_len: int,
+    max_q_seq_len: Tensor,
+    max_kv_seq_len: Tensor,
     num_q_heads: int,
     num_kv_heads: int,
     attn_mask: Sequence[Optional[Tensor]],
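On the maca backend the same change is carried through the C++ extension: attention/attention_kernels.cu is added to the build, paged_attention_v1 is registered through pybind and declared with torch::Tensor& max_seq_len, and paged_decode_attention in maca_ops.py now dispatches to maca_ext_ops.paged_attention_v1 (previously custom_ops) with max_kv_seq_len typed as a Tensor. Since the launcher now sizes its padded logits buffer with DIVIDE_ROUND_UP(512, BLOCK_SIZE) instead of reading max_seq_len, the host apparently no longer needs the scalar value at launch time. Below is a rough sketch of the decode-side call after the patch, with the argument order taken from the pybind schema above; the shapes, dtypes, cache layouts, scale, and the maca_ext_ops import path are assumptions for illustration, not part of the diff.

import torch
from dlinfer.vendor.maca import maca_ext_ops  # assumed import location of the built extension

num_seqs, num_heads, head_size = 4, 8, 128
num_blocks, block_size, max_blocks_per_seq = 64, 16, 32

query = torch.randn(num_seqs, num_heads, head_size, device="cuda", dtype=torch.float16)
# Cache layouts here are assumptions; maca_ops.py only relies on
# value_cache.size(1) == num_kv_heads and value_cache.size(2) == block_size.
key_cache = torch.randn(num_blocks, num_heads, block_size, head_size,
                        device="cuda", dtype=torch.float16)
value_cache = torch.randn(num_blocks, num_heads, block_size, head_size,
                          device="cuda", dtype=torch.float16)
block_tables = torch.zeros(num_seqs, max_blocks_per_seq, device="cuda", dtype=torch.int32)
kv_seq_len = torch.randint(1, 256, (num_seqs,), device="cuda", dtype=torch.int32)
output = torch.empty_like(query)

maca_ext_ops.paged_attention_v1(
    output, query, key_cache, value_cache,
    num_heads,               # num_kv_heads (equal to num_heads in this MHA sketch)
    1.0 / head_size ** 0.5,  # softmax scale
    block_tables, kv_seq_len, block_size,
    kv_seq_len.max(),        # max_seq_len: now a Tensor rather than an int
    None,                    # alibi_slopes
    "auto", 1.0, 1.0,        # kv_cache_dtype, k_scale, v_scale
    0, 0, 0, 0, 0,           # tp_rank and the four blocksparse parameters
)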