@@ -46,7 +46,8 @@ void set_params_fprop(Flash_fwd_params &params,
4646 int window_size_right,
4747 const float softcap,
4848 bool seqlenq_ngroups_swapped=false ,
49- const bool unpadded_lse=false ) {
49+ const bool unpadded_lse=false ,
50+ const bool is_kvc=false ) {
5051
5152 // Reset the parameters
5253 params = {};
@@ -59,11 +60,19 @@ void set_params_fprop(Flash_fwd_params ¶ms,
5960 params.v_ptr = v.data_ptr ();
6061 // All stride are in elements, not bytes.
6162 params.q_row_stride = q.stride (-3 );
62- params.k_row_stride = k.stride (-3 );
63- params.v_row_stride = v.stride (-3 );
6463 params.q_head_stride = q.stride (-2 );
65- params.k_head_stride = k.stride (-2 );
66- params.v_head_stride = v.stride (-2 );
64+ params.is_kvc_cache = is_kvc;
65+ if (!is_kvc) {
66+ params.k_row_stride = k.stride (-3 );
67+ params.v_row_stride = v.stride (-3 );
68+ params.k_head_stride = k.stride (-2 );
69+ params.v_head_stride = v.stride (-2 );
70+ } else {
71+ params.k_row_stride = k.stride (1 );
72+ params.v_row_stride = v.stride (1 );
73+ // head stride not used
74+ }
75+
6776 params.o_ptr = out.data_ptr ();
6877 params.o_row_stride = out.stride (-3 );
6978 params.o_head_stride = out.stride (-2 );
@@ -159,6 +168,7 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
159168 HEADDIM_SWITCH (params.d , [&] {
160169 BOOL_SWITCH (params.is_causal , Is_causal, [&] {
161170 if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
171+                        // NOTE(review): removed leftover debug `assert(false)` here — it aborted the
162172 run_mha_fwd_<elem_type, kHeadDim , Is_causal>(params, stream);
163173 } else {
164174 run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim , Is_causal>(params, stream);
@@ -502,6 +512,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
502512 TORCH_CHECK (block_table.dtype () == torch::kInt32 , " block_table must have dtype torch.int32" );
503513 TORCH_CHECK (block_table.stride (-1 ) == 1 , " block_table must have contiguous last dimension" );
504514 }
515+ const bool is_KVC = paged_KV && (block_table.dim () > 2 );
516+
505517
506518 TORCH_CHECK (q.stride (-1 ) == 1 , " Input tensor must have contiguous last dimension" );
507519 TORCH_CHECK (k.stride (-1 ) == 1 , " Input tensor must have contiguous last dimension" );
@@ -514,11 +526,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
514526 const int batch_size = cu_seqlens_q.numel () - 1 ;
515527 int num_heads = sizes[1 ];
516528 const int head_size_og = sizes[2 ];
517- const int num_heads_k = paged_KV ? k.size (2 ) : k.size (1 );
529+ const int num_heads_k = paged_KV ? (!is_KVC ? k.size (2 ) : block_table.size (1 )) : k.size (1 );
518530
519531 if (softcap > 0.f ) { TORCH_CHECK (p_dropout == 0.f , " Softcapping does not support dropout for now" ); }
520532
521- const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size (1 );
533+ const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size (1 ) : block_table. size ( 2 ) );
522534 const int num_blocks = !paged_KV ? 0 : k.size (0 );
523535 const int page_block_size = !paged_KV ? 1 : k.size (1 );
524536 TORCH_CHECK (!paged_KV || page_block_size % 16 == 0 , " Paged KV cache block size must be divisible by 16" );
@@ -554,13 +566,29 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
554566 CHECK_SHAPE (k, total_k, num_heads_k, head_size_og);
555567 CHECK_SHAPE (v, total_k, num_heads_k, head_size_og);
556568 } else {
557- CHECK_SHAPE (k, num_blocks, page_block_size, num_heads_k, head_size_og);
558- CHECK_SHAPE (v, num_blocks, page_block_size, num_heads_k, head_size_og);
559- CHECK_SHAPE (block_table, batch_size, max_num_blocks_per_seq);
569+ if (!is_KVC) {
570+ CHECK_SHAPE (k, num_blocks, page_block_size, num_heads_k, head_size_og);
571+ CHECK_SHAPE (v, num_blocks, page_block_size, num_heads_k, head_size_og);
572+ CHECK_SHAPE (block_table, batch_size, max_num_blocks_per_seq);
573+ } else {
574+ CHECK_SHAPE (k, num_blocks, page_block_size, head_size_og);
575+ CHECK_SHAPE (v, num_blocks, page_block_size, head_size_og);
576+ // [ batch_size, kv_heads, blocks ]
577+ // printf("batch_size=%d, num_heads_k=%d, max_num_blocks_per_seq=%d",
578+ // batch_size, num_heads_k, max_num_blocks_per_seq);
579+ // std::cout << "block_tables shape\n" << block_table.sizes() << std::endl;
580+ CHECK_SHAPE (block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
581+ }
560582 }
561583
584+ bool seqlen_by_head = false ;
562585 CHECK_SHAPE (cu_seqlens_q, batch_size + 1 );
563- CHECK_SHAPE (cu_seqlens_k, batch_size + 1 );
586+ if (!is_KVC) {
587+ CHECK_SHAPE (cu_seqlens_k, batch_size + 1 );
588+ } else {
589+ seqlen_by_head = cu_seqlens_k.size (0 ) > batch_size + 1 ;
590+ // CHECK_SHAPE(cu_seqlens_k, batch_size + 1 + batch_size * num_heads_k + 1);
591+ }
564592 if (seqused_k.has_value ()){
565593 auto seqused_k_ = seqused_k.value ();
566594 TORCH_CHECK (seqused_k_.dtype () == torch::kInt32 , " seqused_k must have dtype int32" );
@@ -639,12 +667,22 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
639667 window_size_right,
640668 softcap,
641669 seqlenq_ngroups_swapped,
642- /* unpadded_lse*/ true );
670+ /* unpadded_lse*/ true ,
671+ /* is_kvc*/ is_KVC);
643672 params.total_q = total_q;
673+ params.seqlen_by_head = seqlen_by_head;
644674
645675 if (paged_KV) {
646676 params.block_table = block_table.data_ptr <int >();
647- params.block_table_batch_stride = block_table.stride (0 );
677+ if (!is_KVC) {
678+ params.block_table_batch_stride = block_table.stride (0 );
679+ } else {
680+ params.kseqlen_batch_stride = num_heads_k;
681+ params.block_table_batch_stride = block_table.stride (0 );
682+ params.block_table_head_stride = block_table.stride (1 );
683+ }
684+ // std::cout << "\n" << k_padded.strides() << std::endl;
685+ // std::cout << k_padded.sizes() << std::endl;
648686 params.k_batch_stride = k_padded.stride (0 );
649687 params.v_batch_stride = v_padded.stride (0 );
650688 }
@@ -759,6 +797,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
759797 TORCH_CHECK (block_table.dtype () == torch::kInt32 , " block_table must have dtype torch.int32" );
760798 TORCH_CHECK (block_table.stride (-1 ) == 1 , " block_table must have contiguous last dimension" );
761799 }
800+ const bool is_KVC = paged_KV && (block_table.dim () > 2 );
762801
763802 const auto sizes = q.sizes ();
764803
@@ -769,12 +808,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
769808 const int num_heads_og = num_heads;
770809 const int head_size_og = sizes[3 ];
771810
772- const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size (1 );
811+ const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size (1 ) : block_table. size ( 2 ) );
773812 const int num_blocks = !paged_KV ? 0 : kcache.size (0 );
774813 const int page_block_size = !paged_KV ? 1 : kcache.size (1 );
775814 TORCH_CHECK (!paged_KV || page_block_size % 16 == 0 , " Paged KV cache block size must be divisible by 16" );
776815 const int seqlen_k = !paged_KV ? kcache.size (1 ) : max_num_blocks_per_seq * page_block_size;
777- const int num_heads_k = kcache.size (2 );
816+ const int num_heads_k = !is_KVC ? kcache.size (2 ) : block_table.size (1 );
778817 const int batch_size_c = !paged_KV ? kcache.size (0 ) : batch_size;
779818 TORCH_CHECK (batch_size > 0 , " batch size must be postive" );
780819 TORCH_CHECK (head_size_og <= 256 , " FlashAttention forward only supports head dimension at most 256" );
@@ -802,9 +841,16 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
802841 CHECK_SHAPE (kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
803842 CHECK_SHAPE (vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
804843 } else {
805- CHECK_SHAPE (kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
806- CHECK_SHAPE (vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
807- CHECK_SHAPE (block_table, batch_size, max_num_blocks_per_seq);
844+ if (!is_KVC) {
845+ CHECK_SHAPE (kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
846+ CHECK_SHAPE (vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
847+ CHECK_SHAPE (block_table, batch_size, max_num_blocks_per_seq);
848+ } else {
849+ CHECK_SHAPE (kcache, num_blocks, page_block_size, head_size_og);
850+ CHECK_SHAPE (vcache, num_blocks, page_block_size, head_size_og);
851+ // [ batch_size, kv_heads, blocks ]
852+ CHECK_SHAPE (block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
853+ }
808854 }
809855
810856 at::Tensor q_padded, kcache_padded, vcache_padded;
@@ -865,8 +911,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
865911 softmax_scale,
866912 window_size_left,
867913 window_size_right,
868- softcap
869- );
914+ softcap,
915+ /* seqlenq_ngroups_swapped=*/ false ,
916+ /* unpadded_lse=*/ false ,
917+ /* is_kvc=*/ is_KVC);
870918
871919 at::Tensor k, v, k_padded, v_padded;
872920 if (k_.has_value ()) {
@@ -907,8 +955,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
907955 TORCH_CHECK (seqlens_k.dtype () == torch::kInt32 , " seqlens_k must have dtype int32" );
908956 CHECK_DEVICE (seqlens_k);
909957 CHECK_CONTIGUOUS (seqlens_k);
910- CHECK_SHAPE (seqlens_k, batch_size);
958+ if (!is_KVC) {
959+ CHECK_SHAPE (seqlens_k, batch_size);
960+ } else {
961+ CHECK_SHAPE (seqlens_k, batch_size * num_heads_k);
962+ }
911963 params.cu_seqlens_k = static_cast <int *>(seqlens_k.data_ptr ());
964+ params.seqlen_by_head = is_KVC;
912965 }
913966 params.is_seqlens_k_cumulative = !(seqlens_k_.has_value ());
914967
@@ -954,7 +1007,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
9541007
9551008 if (paged_KV) {
9561009 params.block_table = block_table.data_ptr <int >();
957- params.block_table_batch_stride = block_table.stride (0 );
1010+ if (!is_KVC) {
1011+ params.block_table_batch_stride = block_table.stride (0 );
1012+ } else {
1013+ params.kseqlen_batch_stride = num_heads_k;
1014+ params.block_table_batch_stride = block_table.stride (0 );
1015+ params.block_table_head_stride = block_table.stride (1 );
1016+ }
9581017 }
9591018 params.page_block_size = page_block_size;
9601019