@@ -46,7 +46,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       int window_size_right,
                       const float softcap,
                       bool seqlenq_ngroups_swapped=false,
-                      const bool unpadded_lse=false) {
+                      const bool unpadded_lse=false,
+                      const bool is_kvc=false) {
 
     // Reset the parameters
     params = {};
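The new `is_kvc` flag rides along with the existing optional parameters and is stored on the params struct in the next hunk. The members this patch writes (`is_kvc_cache`, `seqlen_by_head`, `block_table_head_stride`, `kseqlen_batch_stride`) are not declared in the hunks shown, so presumably they were added to `Flash_fwd_params` in flash.h; a minimal sketch of the assumed declarations:

// Sketch of the fields this patch assumes exist on Flash_fwd_params
// (names taken from the assignments in the hunks below; the types and
// placement are assumptions, not the actual flash.h change):
struct Flash_fwd_params {
    // ... existing members ...
    bool    is_kvc_cache;             // paged KV cache stored per head ("KVC")
    bool    seqlen_by_head;           // one K length per (batch, head) pair
    int64_t block_table_head_stride;  // head-axis stride of a 3-D block table
    int64_t kseqlen_batch_stride;     // row stride (= num_heads_k) into seqlens_k
};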
@@ -59,11 +60,19 @@ void set_params_fprop(Flash_fwd_params &params,
     params.v_ptr = v.data_ptr();
     // All strides are in elements, not bytes.
     params.q_row_stride = q.stride(-3);
-    params.k_row_stride = k.stride(-3);
-    params.v_row_stride = v.stride(-3);
     params.q_head_stride = q.stride(-2);
-    params.k_head_stride = k.stride(-2);
-    params.v_head_stride = v.stride(-2);
+    params.is_kvc_cache = is_kvc;
+    if (!is_kvc) {
+        params.k_row_stride = k.stride(-3);
+        params.v_row_stride = v.stride(-3);
+        params.k_head_stride = k.stride(-2);
+        params.v_head_stride = v.stride(-2);
+    } else {
+        params.k_row_stride = k.stride(1);
+        params.v_row_stride = v.stride(1);
+        // head stride not used
+    }
+
     params.o_ptr = out.data_ptr();
     params.o_row_stride = out.stride(-3);
     params.o_head_stride = out.stride(-2);
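In the KVC layout, `k` and `v` drop their head axis and become [num_blocks, page_block_size, head_size] (see the shape checks in the later hunks), so the row stride comes from dim 1 and a head stride has no meaning; the head is resolved through the block table instead. A minimal addressing sketch under those assumptions:

// How one K element would be located under either layout (sketch only;
// mirrors the strides set above and assumes a contiguous last dim):
template <typename T>
T *k_element(T *k_base, const Flash_fwd_params &p,
             int block, int row, int head, int d) {
    int64_t off = int64_t(block) * p.k_batch_stride  // physical block
                + int64_t(row) * p.k_row_stride      // token row in the block
                + d;                                 // feature dimension
    if (!p.is_kvc_cache) {
        off += int64_t(head) * p.k_head_stride;      // 4-D cache keeps a head axis
    }                                                // KVC: head picked via block table
    return k_base + off;
}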
@@ -502,6 +511,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
         TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
     }
+    const bool is_KVC = paged_KV && (block_table.dim() > 2);
+
 
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
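The new path is selected purely by the rank of `block_table`: a 2-D table keeps the stock paged-KV behavior, while a 3-D [batch_size, num_heads_k, max_num_blocks_per_seq] table opts into KVC. A caller-side illustration, with hypothetical sizes:

#include <torch/torch.h>

// Hypothetical sizes, for illustration only.
const int batch_size = 4, num_heads_k = 8, max_blocks = 32;
const auto opts = torch::dtype(torch::kInt32).device(torch::kCUDA);
auto bt_2d = torch::zeros({batch_size, max_blocks}, opts);              // dim() == 2: stock paged KV
auto bt_3d = torch::zeros({batch_size, num_heads_k, max_blocks}, opts); // dim() == 3: is_KVC == true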
@@ -514,11 +525,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size_og = sizes[2];
-    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+    const int num_heads_k = paged_KV ? (!is_KVC ? k.size(2) : block_table.size(1)) : k.size(1);
 
     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
 
-    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
     const int num_blocks = !paged_KV ? 0 : k.size(0);
     const int page_block_size = !paged_KV ? 1 : k.size(1);
     TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
@@ -554,13 +565,29 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
         CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
     } else {
-        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        if (!is_KVC) {
+            CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        } else {
+            CHECK_SHAPE(k, num_blocks, page_block_size, head_size_og);
+            CHECK_SHAPE(v, num_blocks, page_block_size, head_size_og);
+            // [batch_size, kv_heads, blocks]
+            // printf("batch_size=%d, num_heads_k=%d, max_num_blocks_per_seq=%d",
+            //        batch_size, num_heads_k, max_num_blocks_per_seq);
+            // std::cout << "block_tables shape\n" << block_table.sizes() << std::endl;
+            CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
+        }
     }
 
+    bool seqlen_by_head = false;
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
-    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (!is_KVC) {
+        CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    } else {
+        seqlen_by_head = cu_seqlens_k.size(0) > batch_size + 1;
+        // CHECK_SHAPE(cu_seqlens_k, batch_size + 1 + batch_size * num_heads_k + 1);
+    }
     if (seqused_k.has_value()){
         auto seqused_k_ = seqused_k.value();
         TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
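With a per-head block table, each KV head of a sequence may hold a different number of tokens, so the KVC path allows `cu_seqlens_k` to grow past the usual batch_size + 1 entries; `seqlen_by_head` records that the caller supplied per-(batch, head) offsets. Only the weaker size test is enforced here, with the precise expected shape left commented out. A sketch of the assumed lookup:

// Assumed cu_seqlens_k conventions in the varlen path (sketch; the exact
// per-head layout is inferred from seqlen_by_head, not spelled out here):
int k_seqlen(const int *cu_k, bool by_head, int b, int h, int num_heads_k) {
    int i = by_head ? b * num_heads_k + h  // one prefix-sum slot per (b, h)
                    : b;                   // one slot per batch entry
    return cu_k[i + 1] - cu_k[i];
}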
@@ -639,12 +666,22 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      window_size_right,
                      softcap,
                      seqlenq_ngroups_swapped,
-                     /*unpadded_lse*/true);
+                     /*unpadded_lse*/true,
+                     /*is_kvc*/is_KVC);
     params.total_q = total_q;
+    params.seqlen_by_head = seqlen_by_head;
 
     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
-        params.block_table_batch_stride = block_table.stride(0);
+        if (!is_KVC) {
+            params.block_table_batch_stride = block_table.stride(0);
+        } else {
+            params.kseqlen_batch_stride = num_heads_k;
+            params.block_table_batch_stride = block_table.stride(0);
+            params.block_table_head_stride = block_table.stride(1);
+        }
+        // std::cout << "\n" << k_padded.strides() << std::endl;
+        // std::cout << k_padded.sizes() << std::endl;
         params.k_batch_stride = k_padded.stride(0);
         params.v_batch_stride = v_padded.stride(0);
     }
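`block_table_batch_stride` still walks batches, and the new `block_table_head_stride` steps across the per-head rows of the 3-D table; `kseqlen_batch_stride` is the matching row width for the flattened per-head sequence lengths. How a kernel would presumably resolve a logical block:

// Sketch of block resolution using the strides set above
// (bidb = batch index, bidh = KV head index; assumed kernel-side use):
inline int physical_block(const Flash_fwd_params &p,
                          int bidb, int bidh, int logical_block) {
    const int *row = p.block_table + bidb * p.block_table_batch_stride;
    if (p.is_kvc_cache) {
        row += bidh * p.block_table_head_stride;  // select this head's row
    }
    return row[logical_block];
}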
@@ -759,6 +796,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
         TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
     }
+    const bool is_KVC = paged_KV && (block_table.dim() > 2);
 
     const auto sizes = q.sizes();
 
@@ -769,12 +807,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     const int num_heads_og = num_heads;
     const int head_size_og = sizes[3];
 
-    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
     const int num_blocks = !paged_KV ? 0 : kcache.size(0);
     const int page_block_size = !paged_KV ? 1 : kcache.size(1);
     TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
     const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
-    const int num_heads_k = kcache.size(2);
+    const int num_heads_k = !is_KVC ? kcache.size(2) : block_table.size(1);
     const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
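Since the KVC cache has no head axis, the head count and per-sequence block count must be read off the block table rather than `kcache`; the two layouts assign different meanings to the same dimensions:

// Where mha_fwd_kvcache reads its geometry (as computed above):
//                          stock paged KV           KVC (3-D block table)
//   num_heads_k            kcache.size(2)           block_table.size(1)
//   max_num_blocks_per_seq block_table.size(1)      block_table.size(2)
//   kcache / vcache shape  [blocks, bs, heads, d]   [blocks, bs, d]   (bs = page_block_size)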
@@ -802,9 +840,16 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
         CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
     } else {
-        CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
-        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        if (!is_KVC) {
+            CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+            CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+        } else {
+            CHECK_SHAPE(kcache, num_blocks, page_block_size, head_size_og);
+            CHECK_SHAPE(vcache, num_blocks, page_block_size, head_size_og);
+            // [batch_size, kv_heads, blocks]
+            CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
+        }
     }
 
     at::Tensor q_padded, kcache_padded, vcache_padded;
@@ -865,8 +910,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     softcap
-                     );
+                     softcap,
+                     /*seqlenq_ngroups_swapped=*/false,
+                     /*unpadded_lse=*/false,
+                     /*is_kvc=*/is_KVC);
 
     at::Tensor k, v, k_padded, v_padded;
     if (k_.has_value()) {
@@ -907,8 +954,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
         CHECK_DEVICE(seqlens_k);
         CHECK_CONTIGUOUS(seqlens_k);
-        CHECK_SHAPE(seqlens_k, batch_size);
+        if (!is_KVC) {
+            CHECK_SHAPE(seqlens_k, batch_size);
+        } else {
+            CHECK_SHAPE(seqlens_k, batch_size * num_heads_k);
+        }
         params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
+        params.seqlen_by_head = is_KVC;
     }
     params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
@@ -954,7 +1006,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 
     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
-        params.block_table_batch_stride = block_table.stride(0);
+        if (!is_KVC) {
+            params.block_table_batch_stride = block_table.stride(0);
+        } else {
+            params.kseqlen_batch_stride = num_heads_k;
+            params.block_table_batch_stride = block_table.stride(0);
+            params.block_table_head_stride = block_table.stride(1);
+        }
     }
     params.page_block_size = page_block_size;
 