diff --git a/AUTHORS b/AUTHORS index e35a781665..e30ecbd552 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1 +1,2 @@ -Tri Dao, trid@cs.stanford.edu \ No newline at end of file +Tri Dao, trid@cs.stanford.edu +Andrew O'Neill (Samsung SDSA), a.oneill@samsung.com \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 74ebbce380..a7c62e7d99 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -150,6 +150,7 @@ if (FA2_ENABLED) SOURCES csrc/flash_attn/flash_api.cpp csrc/flash_attn/flash_api_sparse.cpp + csrc/flash_attn/tree_attention.cpp csrc/flash_attn/flash_api_torch_lib.cpp ${FA2_GEN_SRCS} COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS} diff --git a/benchmarks/benchmark_tree_attention.py b/benchmarks/benchmark_tree_attention.py new file mode 100644 index 0000000000..be941a1e31 --- /dev/null +++ b/benchmarks/benchmark_tree_attention.py @@ -0,0 +1,343 @@ +# Copyright (c) 2025, Samsung SDSA. + +import random +import torch + +from vllm_flash_attn.utils.benchmark import benchmark_forward + +from vllm_flash_attn.flash_attn_interface import ( + flash_attn_varlen_func, + tree_attention, +) +from vllm_flash_attn.utils.tree import ( + create_tree_mask, + generate_q_and_block_kvcache, + treeify_output, +) + + +def run_tree_attention_benchmark( + seqlen_q: int = 1024, + seqlen_k: int = 1024, + spec_len: tuple[int] = (8,8), + random_seq_len: bool = False, + random_spec_len: bool = False, + batch_size: int = 8, + nheads: int = 16, + head_dim: int = 128, + paged_kv_block_size: int = 256, + dtype: torch.dtype = torch.float16, + device: str = "cuda", +): + """ + Benchmark tree_attention vs flash_attn_varlen_func performance. + + Similar to test_paged_tree_attention but focused on performance measurement. + """ + print("Benchmarking with:") + print(f" seqlen_q: {seqlen_q}, seqlen_k: {seqlen_k}") + print(f" spec_len: {spec_len}, random_seq_len: {random_seq_len}, random_spec_len: {random_spec_len}") + print(f" batch_size: {batch_size}, nheads: {nheads}, head_dim: {head_dim}") + print(f" paged_kv_block_size: {paged_kv_block_size}, dtype: {dtype}") + + torch.set_default_device(device) + torch.cuda.manual_seed_all(42) # Fixed seed for reproducibility + + # Generate random sequence lengths and spec lengths similar to the test + if random_seq_len: + q_seqlens = [seqlen_q + random.randint(0, 20) for _ in range(batch_size)] + k_seqlens = [seqlen_k + random.randint(0, 20) for _ in range(batch_size)] + else: + q_seqlens = [seqlen_q]*batch_size + k_seqlens = [seqlen_k]*batch_size + + if random_spec_len: + speclens = [(spec_len[0]+random.randint(0, 7), spec_len[1]+random.randint(1, 2)) for _ in range(batch_size)] + else: + speclens = [spec_len]*batch_size + + # Generate test data using the utility function + ( + q_spec_tree, + q_seqlens_tree, + q_spec_batch, + q_seqlens_batch, + tree_block_table, + k_spec_tree, + v_spec_tree, + k_seqlens_tree, + batch_block_table, + k_spec_batch, + v_spec_batch, + k_seqlens_batch, + ) = generate_q_and_block_kvcache( + q_seqlens, k_seqlens, speclens, paged_kv_block_size, nheads, head_dim, device, dtype + ) + + # Create tree mask and cumulative sequence lengths + tree_mask = create_tree_mask(speclens, device) + tree_mask_lens = torch.tensor([0] + [i*j for i,j in speclens], dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) + cu_seqlens_q_tree = torch.tensor([0] + q_seqlens_tree, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) + seqused_k_tree = torch.tensor(k_seqlens_tree, dtype=torch.int32) + cu_seqlens_q_batch = torch.tensor([0] + q_seqlens_batch, dtype=torch.int32).cumsum(dim=0, 
dtype=torch.int32) + seqused_k_batch = torch.tensor(k_seqlens_batch, dtype=torch.int32) + + + print("\nRunning benchmarks...") + + # Benchmark tree_attention + _, tree_measurement = benchmark_forward( + tree_attention, + q_spec_tree, + k_spec_tree, + v_spec_tree, + max(q_seqlens_tree), + cu_seqlens_q_tree, + max(k_seqlens_tree), + tree_mask, + tree_mask_lens, + seqused_k=seqused_k_tree, + block_table=tree_block_table, + desc="tree_attention", + verbose=False + ) + tree_time = tree_measurement.mean + print(f"tree_attention average time: {tree_time:.6f} seconds") + + # Benchmark flash_attn_varlen_func + _, varlen_measurement = benchmark_forward( + flash_attn_varlen_func, + q_spec_batch, + k_spec_batch, + v_spec_batch, + max(q_seqlens_batch), + cu_seqlens_q_batch, + max(k_seqlens_batch), + seqused_k=seqused_k_batch, + causal=True, + block_table=batch_block_table, + desc="flash_attn_varlen_func", + verbose=False + ) + varlen_time = varlen_measurement.mean + print(f"flash_attn_varlen_func average time: {varlen_time:.6f} seconds") + + # Calculate speedup + if varlen_time > 0: + speedup = varlen_time / tree_time + print(f"Speedup (varlen/tree): {speedup:.2f}x") + if speedup > 1: + print(f"tree_attention is {speedup:.2f}x faster") + else: + print(f"flash_attn_varlen_func is {1/speedup:.2f}x faster") + + # Verify correctness + print("\nVerifying correctness...") + tree_output = tree_attention( + q_spec_tree, + k_spec_tree, + v_spec_tree, + max(q_seqlens_tree), + cu_seqlens_q_tree, + max(k_seqlens_tree), + tree_mask, + tree_mask_lens, + seqused_k=seqused_k_tree, + block_table=tree_block_table, + ) + varlen_output = flash_attn_varlen_func( + q_spec_batch, + k_spec_batch, + v_spec_batch, + max(q_seqlens_batch), + cu_seqlens_q_batch, + max(k_seqlens_batch), + seqused_k=seqused_k_batch, + causal=True, + block_table=batch_block_table, + ) + varlen_output_treeified = treeify_output(varlen_output, q_seqlens, speclens) + try: + torch.testing.assert_close(tree_output, varlen_output_treeified, atol=2e-2, rtol=1e-2) + except AssertionError as e: + print("✗ Outputs differ significantly!") + print(e) + else: + print("✓ Outputs match within tolerance") + finally: + max_diff = torch.max(torch.abs(tree_output - varlen_output_treeified)).item() + print(f"Maximum difference between outputs: {max_diff:.6f}") + + return { + 'tree_time': tree_time, + 'varlen_time': varlen_time, + 'speedup': varlen_time / tree_time if varlen_time > 0 else float('inf'), + 'max_diff': max_diff, + 'config': { + 'seqlen_q': seqlen_q, + 'seqlen_k': seqlen_k, + 'batch_size': batch_size, + 'nheads': nheads, + 'head_dim': head_dim, + 'paged_kv_block_size': paged_kv_block_size, + 'dtype': str(dtype), + 'q_spec_tree.shape': q_spec_tree.shape, + 'k_spec_tree.shape': k_spec_tree.shape, + 'tree_mask.shape': tree_mask.shape, + } + } + + +def run_decoding_benchmark(): + """Run benchmarks for decoding scenario with seqlen_q=0.""" + configs = [ + # Small sequences with different spec_len and block sizes + {'seqlen_q': 0, 'seqlen_k': 128, 'batch_size': 4, 'nheads': 8, 'head_dim': 128, 'spec_len': (1, 2), 'paged_kv_block_size': 16}, + {'seqlen_q': 0, 'seqlen_k': 256, 'batch_size': 4, 'nheads': 8, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 16}, + + # Medium sequences with varied spec_len and block sizes + {'seqlen_q': 0, 'seqlen_k': 512, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (1, 2), 'paged_kv_block_size': 256}, + {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (3, 4), 
'paged_kv_block_size': 256}, + + # Large sequences with larger block sizes + {'seqlen_q': 0, 'seqlen_k': 2048, 'batch_size': 4, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 512}, + + # Different head dimensions with varied block sizes + {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 64, 'spec_len': (1, 2), 'paged_kv_block_size': 256}, + {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 256, 'spec_len': (2, 3), 'paged_kv_block_size': 512}, + + # Different batch sizes with randomization and block sizes + {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 2, 'nheads': 16, 'head_dim': 128, 'spec_len': (1, 2), 'random_spec_len': True, 'paged_kv_block_size': 16}, + {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 16, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'random_seq_len': True, 'paged_kv_block_size': 256}, + + # High spec_len scenarios with different block sizes + {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (4, 5), 'paged_kv_block_size': 256}, + {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (6, 8), 'paged_kv_block_size': 512}, + + # Block size comparison scenarios + {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 16}, + {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 256}, + {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 512}, + ] + + print("=" * 80) + print("DECODING BENCHMARK (seqlen_q=0)") + print("=" * 80) + print("This benchmark represents the decoding scenario where tree attention") + print("can be compared against batch expansion for generation tasks.") + print("=" * 80) + + results = [] + for i, config in enumerate(configs): + print(f"\n[{i+1}/{len(configs)}] Decoding Configuration:") + result = run_tree_attention_benchmark(**config) + results.append(result) + print("-" * 80) + + # Summary + print("\n" + "=" * 80) + print("DECODING BENCHMARK SUMMARY") + print("=" * 80) + print(f"{'Config':<18} {'Tree(ms)':<10} {'Varlen(ms)':<12} {'Speedup':<10} {'Max Diff':<12}") + print("-" * 80) + + for i, result in enumerate(results): + config = result['config'] + config_str = f"{config['seqlen_q']}:{config['seqlen_k']}:{config['tree_mask.shape'][0]}:{config['paged_kv_block_size']}" + tree_ms = result['tree_time'] * 1000 + varlen_ms = result['varlen_time'] * 1000 + speedup = result['speedup'] + max_diff = result['max_diff'] + + print(f"{config_str:<18} {tree_ms:<10.3f} {varlen_ms:<12.3f} {speedup:<10.2f}x {max_diff:<12.6f}") + + return results + + +def run_comprehensive_benchmark(): + """Run benchmarks across different configurations.""" + configs = [ + # Small sequences with different spec_len and block sizes + {'seqlen_q': 128, 'seqlen_k': 128, 'batch_size': 4, 'nheads': 8, 'head_dim': 128, 'spec_len': (1, 2), 'paged_kv_block_size': 16}, + {'seqlen_q': 256, 'seqlen_k': 256, 'batch_size': 4, 'nheads': 8, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 16}, + + # Medium sequences with varied spec_len and block sizes + {'seqlen_q': 512, 'seqlen_k': 512, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (1, 2), 'paged_kv_block_size': 256}, + {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (3, 4), 
'paged_kv_block_size': 256}, + + # Large sequences with larger block sizes + {'seqlen_q': 2048, 'seqlen_k': 2048, 'batch_size': 4, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 512}, + + # Different head dimensions with varied block sizes + {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 64, 'spec_len': (1, 2), 'paged_kv_block_size': 256}, + {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 256, 'spec_len': (2, 3), 'paged_kv_block_size': 512}, + + # Different batch sizes with randomization and block sizes + {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 2, 'nheads': 16, 'head_dim': 128, 'spec_len': (1, 2), 'random_spec_len': True, 'paged_kv_block_size': 16}, + {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 16, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'random_seq_len': True, 'paged_kv_block_size': 256}, + + # High spec_len scenarios with different block sizes + {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (4, 5), 'paged_kv_block_size': 256}, + {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (6, 8), 'paged_kv_block_size': 512}, + + # Mixed randomization scenarios with block sizes + {'seqlen_q': 512, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'random_seq_len': True, 'random_spec_len': True, 'paged_kv_block_size': 256}, + + # Block size comparison scenarios + {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 16}, + {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 256}, + {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 512}, + ] + + print("=" * 80) + print("COMPREHENSIVE TREE ATTENTION BENCHMARK") + print("=" * 80) + + results = [] + for i, config in enumerate(configs): + print(f"\n[{i+1}/{len(configs)}] Configuration:") + result = run_tree_attention_benchmark(**config) + results.append(result) + print("-" * 80) + + # Summary + print("\n" + "=" * 80) + print("BENCHMARK SUMMARY") + print("=" * 80) + print(f"{'Config':<18} {'Tree(ms)':<10} {'Varlen(ms)':<12} {'Speedup':<10} {'Max Diff':<12}") + print("-" * 80) + + for i, result in enumerate(results): + config = result['config'] + config_str = f"{config['seqlen_q']}:{config['seqlen_k']}:{config['tree_mask.shape'][0]}:{config['paged_kv_block_size']}" + tree_ms = result['tree_time'] * 1000 + varlen_ms = result['varlen_time'] * 1000 + speedup = result['speedup'] + max_diff = result['max_diff'] + + print(f"{config_str:<18} {tree_ms:<10.3f} {varlen_ms:<12.3f} {speedup:<10.2f}x {max_diff:<12.6f}") + + return results + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("CUDA is not available. 
This benchmark requires GPU.") + exit(1) + + print("Tree Attention vs Flash Attention Varlen Benchmark") + print(f"PyTorch version: {torch.__version__}") + print(f"CUDA version: {torch.version.cuda}") + print(f"Device: {torch.cuda.get_device_name()}") + + # Run single benchmark + print("\n" + "=" * 80) + print("SINGLE BENCHMARK (1024x1024, batch=8)") + print("=" * 80) + run_tree_attention_benchmark() + + # Run decoding benchmark + run_decoding_benchmark() + + # Run comprehensive benchmark + run_comprehensive_benchmark() \ No newline at end of file diff --git a/csrc/flash_attn/flash_api_torch_lib.cpp b/csrc/flash_attn/flash_api_torch_lib.cpp index d1299c54cd..c70d35e79a 100644 --- a/csrc/flash_attn/flash_api_torch_lib.cpp +++ b/csrc/flash_attn/flash_api_torch_lib.cpp @@ -101,6 +101,31 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_ const bool return_softmax, std::optional gen_); +/////////////////////////// From flash_api_tree.cpp ////////////////////////// + +std::vector +tree_attention(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + std::optional &leftpad_k_, // batch_size + std::optional &block_table_, // batch_size x max_num_blocks_per_seq + std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + const float softcap, + const bool return_softmax, + std::optional gen_, + const at::Tensor &tree_mask, + const at::Tensor &tree_mask_lens); + + /** * Torch Library Registration */ @@ -134,6 +159,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "bool is_causal, float softcap, bool return_softmax, " "Generator? gen) -> Tensor[]"); ops.impl("varlen_fwd_sparse", torch::kCUDA, &mha_varlen_fwd_sparse); + + ops.def("tree_attention(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor cu_seqlens_q, " + "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? block_table, Tensor? alibi_slopes, " + "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, " + "float softcap, bool return_softmax, " + "Generator? gen, Tensor tree_mask, Tensor tree_mask_lens) -> Tensor[]"); + ops.impl("tree_attention", torch::kCUDA, make_pytorch_shim(&tree_attention)); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME); diff --git a/csrc/flash_attn/src/block_info_tree.h b/csrc/flash_attn/src/block_info_tree.h new file mode 100644 index 0000000000..ea75db49e6 --- /dev/null +++ b/csrc/flash_attn/src/block_info_tree.h @@ -0,0 +1,31 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao, Samsung SDSA. 
+ ******************************************************************************/ + +#pragma once + +#include "namespace_config.h" +#include "block_info.h" + +namespace FLASH_NAMESPACE { + +template +struct BlockInfoTree : BlockInfo { + + template + __device__ BlockInfoTree(const Params ¶ms, const int bidb) + : BlockInfo(params, bidb) + , sum_s_tree(params.tree_mask_lens_ptr[bidb]) + , actual_tree_len(params.tree_mask_lens_ptr[bidb+1]-sum_s_tree) + { + } + + __forceinline__ __device__ uint32_t tree_offset() const { + return uint32_t(sum_s_tree); + } + + const int sum_s_tree; + const int actual_tree_len; +}; + +} // FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_hdim128_bf16_sm80.cu new file mode 100644 index 0000000000..6e661c10a1 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_hdim128_bf16_sm80.cu @@ -0,0 +1,14 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + run_mha_fwd_tree_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_hdim128_fp16_sm80.cu new file mode 100644 index 0000000000..2cda05fdc0 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_hdim128_fp16_sm80.cu @@ -0,0 +1,14 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + run_mha_fwd_tree_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_hdim160_bf16_sm80.cu new file mode 100644 index 0000000000..a398dd8baf --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_hdim160_bf16_sm80.cu @@ -0,0 +1,14 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + run_mha_fwd_tree_hdim160(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_hdim160_fp16_sm80.cu new file mode 100644 index 0000000000..36958d0fb0 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_hdim160_fp16_sm80.cu @@ -0,0 +1,14 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + run_mha_fwd_tree_hdim160(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_hdim192_bf16_sm80.cu new file mode 100644 index 0000000000..fed1f84304 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_hdim192_bf16_sm80.cu @@ -0,0 +1,14 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + run_mha_fwd_tree_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_hdim192_fp16_sm80.cu new file mode 100644 index 0000000000..464515a251 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_hdim192_fp16_sm80.cu @@ -0,0 +1,14 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + run_mha_fwd_tree_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_hdim256_bf16_sm80.cu new file mode 100644 index 0000000000..c21ea843b3 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_hdim256_bf16_sm80.cu @@ -0,0 +1,14 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + run_mha_fwd_tree_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_hdim256_fp16_sm80.cu new file mode 100644 index 0000000000..e416b9ddb4 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_hdim256_fp16_sm80.cu @@ -0,0 +1,14 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + run_mha_fwd_tree_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_hdim32_bf16_sm80.cu new file mode 100644 index 0000000000..7764f099f8 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_hdim32_bf16_sm80.cu @@ -0,0 +1,14 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + run_mha_fwd_tree_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_hdim32_fp16_sm80.cu new file mode 100644 index 0000000000..f964b3736a --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_hdim32_fp16_sm80.cu @@ -0,0 +1,14 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + run_mha_fwd_tree_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_hdim64_bf16_sm80.cu new file mode 100644 index 0000000000..2a812bcb1a --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_hdim64_bf16_sm80.cu @@ -0,0 +1,14 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + run_mha_fwd_tree_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_hdim64_fp16_sm80.cu new file mode 100644 index 0000000000..4250cf0777 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_hdim64_fp16_sm80.cu @@ -0,0 +1,14 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + run_mha_fwd_tree_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_hdim96_bf16_sm80.cu new file mode 100644 index 0000000000..68c49393d8 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_hdim96_bf16_sm80.cu @@ -0,0 +1,14 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + run_mha_fwd_tree_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_hdim96_fp16_sm80.cu new file mode 100644 index 0000000000..e28d43e505 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_hdim96_fp16_sm80.cu @@ -0,0 +1,14 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + run_mha_fwd_tree_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_kernel.h b/csrc/flash_attn/src/flash_fwd_tree_kernel.h new file mode 100644 index 0000000000..af254218b0 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_kernel.h @@ -0,0 +1,1300 @@ +/****************************************************************************** +* Copyright (c) 2025, Tri Dao, Samsung SDSA. +******************************************************************************/ + +#pragma once + +#include "namespace_config.h" +#include "philox_unpack.cuh" // For at::cuda::philox::unpack + +#include + +#include +#include +#include + +#include "block_info_tree.h" +#include "kernel_traits.h" +#include "utils.h" +#include "softmax.h" +#include "mask_tree.h" +#include "dropout.h" +#include "rotary.h" + +namespace FLASH_NAMESPACE { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto get_lse_tile(const Params ¶ms, const int bidb, const int bidh, const int m_block, const BlockInfoTree &binfo) { + // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path. + // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick. + // Otherwise, it's written as (h, b, seqlen_q). + const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped; + auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0; + auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + lse_offset); + + auto lse_shape = varlen_q ? 
make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q); + auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : ( + params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1) + ); + + auto lse_layout = make_layout(lse_shape, lse_stride); + Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout); + auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _); + return local_tile(mLSE_slice, Shape>{}, make_coord(m_block)); +} + + +template +inline __device__ void compute_attn_1rowblock_tree(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + auto seed_offset = at::cuda::philox::unpack(params.philox_args); + FLASH_NAMESPACE::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, + bidb, bidh, tidx, params.h); + + // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might + // exit early and no one saves the rng states. + if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { + params.rng_state[0] = std::get<0>(seed_offset); + params.rng_state[1] = std::get<1>(seed_offset); + } + + const BlockInfoTree binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_block_min =0; + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); + // } + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. + // Otherwise we might read OOB elements from gK and gV. 
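// (The bound above mirrors the non-tree kernel: n_block_max starts as the number of kBlockN
// tiles covering actual_seqlen_k and is then clamped by the causal-style limit derived from
// kBlockM, the seqlen_q/seqlen_k difference and window_size_right; the tree mask itself is only
// applied later, inside TreeMask::apply_mask. When the clamp leaves no K/V tiles for this row
// block, the branch below writes an all-zero output tile, sets gLSE to +infinity for the valid
// rows, and returns early.)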
+ if (n_block_max <= n_block_min) { + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + FLASH_NAMESPACE::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + } + return; + } + // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). 
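// In outline, the per-tile work that follows (with n_block walking from n_block_max - 1 down
// to n_block_min) is:
//   1. async-copy the K and V tiles for this n_block into shared memory
//   2. acc_s = Q * K^T (gemm), then optional softcap
//   3. TreeMask::apply_mask masks acc_s, using the flattened per-sequence tree mask mMask
//      (located through binfo.tree_offset() / actual_tree_len) alongside the usual bounds checks
//   4. online-softmax rescaling of the running accumulator (softmax_rescale_o)
//   5. acc_o += P * V (gemm_rs)
// Only the first n_masking_steps iterations take the more expensive masking/OOB path; the
// remaining iterations run the second, cheaper loop further below.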
+ + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.k_row_stride, params.k_head_stride, _1{})); + Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.v_row_stride, params.v_head_stride, _1{})); + Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + Shape, Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + Tensor mMask = make_tensor(make_gmem_ptr(params.tree_mask_ptr+binfo.tree_offset()), + make_shape(binfo.actual_tree_len)); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 
0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor tSgS = thr_mma.partition_C(gP); + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + // if (cute::thread0()) {smem_thr_copy_Q.print_all();} + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) + // if (cute::thread0()) { + // print(tScQ.layout()); printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<0>(tScQ(i))); + // } + // printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<1>(tScQ(i))); + // } + // printf("\n"); + // } + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < 
params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + + // // if (cute::thread(1, 0)) { print(tQsQ); } + // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); + // // if (cute::thread0()) { print(sQNoSwizzle); } + + if (Kernel_traits::Share_Q_K_smem) { + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } + // __syncthreads(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + FLASH_NAMESPACE::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + + clear(acc_o); + + FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + FLASH_NAMESPACE::TreeMask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, alibi_slope); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = Is_even_MN ? 
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1; + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + FLASH_NAMESPACE::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap){ + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); + } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16, mMask + ); + + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor rP_drop = make_fragment_like(rP); + cute::copy(rP, rP_drop); + dropout.template apply_dropout( + rP_drop, block_row_idx, block_col_idx, kNWarps + ); + cute::copy(rP_drop, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); + } + if (Is_dropout) { + dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
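// The reshape below is a register-layout cast only: rP was produced as a C (accumulator)
// fragment of the Q*K^T GEMM and is reused as the A operand of the P*V GEMM, so
// convert_layout_acc_Aregs reinterprets the same registers in the A-operand layout without
// moving any data.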
+ Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + // if (cute::thread0()) { print(tOrP); } + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + FLASH_NAMESPACE::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + if constexpr (Is_softcap){ + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); + } + + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16, mMask + ); + + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor rP_drop = make_fragment_like(rP); + cute::copy(rP, rP_drop); + dropout.template apply_dropout( + rP_drop, block_row_idx, block_col_idx, kNWarps + ); + cute::copy(rP_drop, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); + } + if (Is_dropout) { + dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. 
+ if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + Tensor tOrO = make_tensor(shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + FLASH_NAMESPACE::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock_splitkv_tree(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. 
+ const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::GmemTiledCopyO, + typename Kernel_traits::GmemTiledCopyOaccum + >; + using ElementO = std::conditional_t; + + const BlockInfoTree binfo(params, bidb); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } + // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((binfo.actual_seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = n_split_idx * n_blocks_per_split; + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. + // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? 
params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrOaccum); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + FLASH_NAMESPACE::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgOaccum); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; } + } + return; + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + // We move K and V to the last block. + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; + const index_t row_offset_k = block_table == nullptr + ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride + : (bidh / params.h_h_k_ratio) * params.k_head_stride; // block addresses are later resolved per-thread + const index_t row_offset_v = block_table == nullptr + ? 
binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride + : (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + Tensor mMask = make_tensor(make_gmem_ptr(params.tree_mask_ptr+binfo.tree_offset()), + make_shape(binfo.actual_tree_len)); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV; + auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); + + Tensor tKgK_ = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK_ = gmem_thr_copy_KV.partition_D(sK); + Tensor tVgV_ = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV_ = gmem_thr_copy_KV.partition_D(sV); + + Tensor tKgK = make_tensor(tKgK_.data(), reshape_thread_tile(tKgK_.layout())); + Tensor tKsK = make_tensor(tKsK_.data(), reshape_thread_tile(tKsK_.layout())); + Tensor tVgV = make_tensor(tVgV_.data(), reshape_thread_tile(tVgV_.layout())); + Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout())); + + if (block_table != nullptr) { + auto final_block_size = binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN; + tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block_max - 1, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride, final_block_size); + tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block_max - 1, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride, final_block_size); + } + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + 
// Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV_ = gmem_thr_copy_KV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tKVcKV = make_tensor(tKVcKV_.data(), reshape_thread_tile(tKVcKV_.layout())); + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + // Copy from Knew to K, optionally apply rotary embedding. + if constexpr (Append_KV) { + typename Kernel_traits::GmemTiledCopyRotcossinPaged gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinContPaged gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 
0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos_ = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin_ = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont_ = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont_ = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + + Tensor tRgCos = make_tensor(tRgCos_.data(), reshape_thread_tile(tRgCos_.layout())); + Tensor tRgSin = make_tensor(tRgSin_.data(), reshape_thread_tile(tRgSin_.layout())); + Tensor tRgCosCont = make_tensor(tRgCosCont_.data(), reshape_flatten_thread_tile(tRgCosCont_.layout())); + Tensor tRgSinCont = make_tensor(tRgSinCont_.data(), reshape_flatten_thread_tile(tRgSinCont_.layout())); + + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + const index_t row_offset_knew = bidb * params.knew_batch_stride + + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + const index_t row_offset_vnew = bidb * params.vnew_batch_stride + + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. 
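+    // Concretely (an illustrative restatement of the comment above, not new logic): because
+    // seqlen_k_cache * knew_row_stride is subtracted from the base pointer, global key row i of gKnew
+    // resolves to row (i - seqlen_k_cache) of knew_ptr, and the copy loop below is intended to touch
+    // only rows at or beyond seqlen_k_cache via the min-index bound passed to copy_w_min_idx.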
+ Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV_new; + auto gmem_thr_copy_KV_new = gmem_tiled_copy_KV_new.get_thread_slice(tidx); + Tensor tKgKnew_ = gmem_thr_copy_KV_new.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew_ = gmem_thr_copy_KV_new.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + auto tKgKnew = make_tensor(tKgKnew_.data(), reshape_thread_tile(tKgKnew_.layout())); + auto tVgVnew = make_tensor(tVgVnew_.data(), reshape_thread_tile(tVgVnew_.layout())); + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + auto tKgK_data = tKgK.data(); + auto tVgV_data = tVgV.data(); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + FLASH_NAMESPACE::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + FLASH_NAMESPACE::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + FLASH_NAMESPACE::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + FLASH_NAMESPACE::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + + } + } + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + if (n_block > n_block_copy_min) { + tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block - 1, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); + tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block - 1, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); + } + } + } + // Need this before we can read in K again, so that we'll see the updated K values. 
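+        // (Right after the barrier, tKgK / tVgV are reset to the tKgK_data / tVgV_data pointers saved
+        // before this copy loop, so the main attention loop below starts again from the last K/V block.)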
+ __syncthreads(); + tKgK.data() = tKgK_data; + tVgV.data() = tVgV_data; + } + + // Read Q from gmem to smem, optionally apply rotary embedding. + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + FLASH_NAMESPACE::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (m_block * kBlockM)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + FLASH_NAMESPACE::copy_rotary_interleaved( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } else { + FLASH_NAMESPACE::copy_rotary_contiguous( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + FLASH_NAMESPACE::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + // FLASH_NAMESPACE::cp_async_wait<0>(); + // __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } + // __syncthreads(); + + clear(acc_o); + + FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + FLASH_NAMESPACE::TreeMask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, alibi_slope); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. 
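+    // Illustrative count only: with kBlockM = 128 and kBlockN = 64 the causal diagonal crosses
+    // ceil_div(128, 64) = 2 K-blocks per query block, so at least 2 masking iterations run.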
+ + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1; + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); + } + FLASH_NAMESPACE::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + FLASH_NAMESPACE::copy( + gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + FLASH_NAMESPACE::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap){ + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); + } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16, mMask + ); + + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + if (n_block > n_block_min) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block - 1, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); + } + FLASH_NAMESPACE::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
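+        // (rP is consumed as the A operand of the second GEMM below, so its accumulator layout is viewed
+        // through the A-register layout expected by the MMA; this is a reinterpretation of the same
+        // registers, not a data movement.)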
+ Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); + } + + FLASH_NAMESPACE::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + FLASH_NAMESPACE::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + if constexpr (Is_softcap){ + FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); + } + + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block - 1, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); + } + FLASH_NAMESPACE::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16, mMask + ); + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + + FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + // if (cute::thread0()) { print(lse); } + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum + >; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + if constexpr (Split) { __syncthreads(); } + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ? + ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb) + ) + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
+ Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + FLASH_NAMESPACE::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_tree(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within + // the attention matrix. This way, as long as we have the batch, head, and the location of + // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. + + FLASH_NAMESPACE::compute_attn_1rowblock_tree(params, bidb, bidh, m_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_splitkv_tree(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = Split ? blockIdx.z / params.h : blockIdx.y; + // The block index for the head. + const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; + const int n_split_idx = Split ? blockIdx.y : 0; + const int num_n_splits = Split ? gridDim.y : 1; + FLASH_NAMESPACE::compute_attn_1rowblock_splitkv_tree(params, bidb, bidh, m_block, n_split_idx, num_n_splits); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void combine_attn_seqk_parallel_tree(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + constexpr int kMaxSplits = 1 << Log_max_splits; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = Kernel_traits::kNThreads; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); + + // Shared memory. 
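+    // Rough size sketch (assuming ElementAccum is a 4-byte float): at the asserted limits
+    // kMaxSplits = 128 and kBlockM = 32 this buffer is 128 * 33 * 4 bytes, roughly 16.5 KiB,
+    // comfortably under the 48 KiB static shared-memory limit.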
+ // kBlockM + 1 instead of kBlockM to reduce bank conflicts. + __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1]; + + // The thread and block index. + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + const index_t lse_size = params.b * params.h * params.seqlen_q; + + const index_t row_offset_lse = bidx * kBlockM; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), + Shape, Int>{}, + make_stride(lse_size, _1{})); + + // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile. + // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}. + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}. + Layout flat_layout = make_layout(lse_size); + Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b)); + auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q); + Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride); + Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout)); + + Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), final_layout); + + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; + + // Read the LSE values from gmem and store them in shared memory, then transpose them. + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + if (row < kMaxSplits) { sLSE[row][col] = lse; } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } + } + // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } + __syncthreads(); + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // kBlockM rows, so each time we load we can load 128 / kBlockM rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + + // Compute the logsumexp of the LSE along the split dimension. 
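+    // What the next few lines compute (a restatement, not new logic):
+    //   lse_combined = log(sum_s exp(lse_s)) = lse_max + log(sum_s exp(lse_s - lse_max)),
+    // with the max subtracted first for numerical stability. If every split's LSE is -inf, the result is
+    // forced to +inf so that the per-split scales exp(lse_s - lse_combined) become 0 rather than NaN.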
+ ElementAccum lse_max = lse_accum(0); + #pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); + #pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } + if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { + if (params.unpadded_lse) { + const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; + if (lse_offset < lse_size) { + gLSE_unpadded(lse_offset) = lse_logsum; + } + } else { + gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + } + } + // Store the scales exp(lse - lse_logsum) in shared memory. + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); } + } + __syncthreads(); + + const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + Stride, _1>{}); + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom, ElementAccum>{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } + } + // Load Oaccum in then scale and accumulate to O + for (int split = 0; split < params.num_splits; ++split) { + FLASH_NAMESPACE::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE[split][row]; + #pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { + #pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); + } + } + // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); } + } + 
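+        // Why this weighting is exact (a restatement, not new logic): each split's partial output was
+        // normalized by its own exp(lse_s), so scaling by exp(lse_s - lse_combined) and summing over the
+        // splits reproduces the single-pass softmax output; the weights sum to 1 by the definition of
+        // lse_combined above.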
tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; + } + // if (cute::thread0()) { print_tensor(tOrO); } + + Tensor rO = FLASH_NAMESPACE::convert_type(tOrO); + // Write to gO + #pragma unroll + for (int m = 0; m < size<1>(rO); ++m) { + const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); + if (idx < params.b * params.h * params.seqlen_q) { + const int batch_idx = idx / (params.h * params.seqlen_q); + const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; + // The index to the rows of Q + const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + + head_idx * params.o_head_stride + row * params.o_row_stride; + #pragma unroll + for (int k = 0; k < size<2>(rO); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), + Shape(rO))::value>>{}, Stride<_1>{}); + // TODO: Should check if this is using vectorized store, but it seems pretty fast + copy(rO(_, m, k), gO); + // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } + // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); + } + } + } + } +} + +} // namespace FLASH_NAMESPACE diff --git a/csrc/flash_attn/src/flash_fwd_tree_launch_template.h b/csrc/flash_attn/src/flash_fwd_tree_launch_template.h new file mode 100644 index 0000000000..0852cfdbc8 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_launch_template.h @@ -0,0 +1,229 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao, Samsung SDSA. + ******************************************************************************/ + +#pragma once +#include "namespace_config.h" +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include "static_switch.h" +#include "hardware_info.h" +#include "flash_tree.h" +#include "flash_fwd_tree_kernel.h" + +namespace FLASH_NAMESPACE { + +// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define ARCH_SUPPORTS_FLASH +#define KERNEL_PARAM_MODIFIER __grid_constant__ +#else +#define KERNEL_PARAM_MODIFIER +#endif + +// Define a macro for unsupported architecture handling to centralize the error message +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); + +// Use a macro to clean up kernel definitions +#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) 
\ +template \ +__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params_tree params) + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_tree_kernel, bool Is_dropout, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { + #if defined(ARCH_SUPPORTS_FLASH) + FLASH_NAMESPACE::compute_attn_tree(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_tree_kernel, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) { + #if defined(ARCH_SUPPORTS_FLASH) + FLASH_NAMESPACE::compute_attn_splitkv_tree(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { + static_assert(Log_max_splits >= 1); + FLASH_NAMESPACE::combine_attn_seqk_parallel_tree(params); +} + +template +void run_flash_fwd_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + // printf("smem_size = %d\n", smem_size); + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + const bool return_softmax = params.p_ptr != nullptr; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + auto kernel = &flash_fwd_tree_kernel; + + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); +} + + +template +void run_flash_splitkv_fwd_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? 
params.b * params.h : params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_splitkv_tree_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); + }); + if (params.num_splits > 1) { + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } +} + +template +void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + constexpr static int kBlockM = 64; // Fixed for all head dimensions + // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, + // and for headdim 192 with block size 64 x 128. + // Also for headdim 160 with block size 64 x 128 after the rotary addition. + constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 
128 : 64); + run_flash_splitkv_fwd_tree>(params, stream); +} + +template +void run_mha_fwd_tree_hdim32(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 32; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_fwd_tree, Is_dropout>(params, stream); + }); +} + +template +void run_mha_fwd_tree_hdim64(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 64; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_fwd_tree, Is_dropout>(params, stream); + }); +} + +template +void run_mha_fwd_tree_hdim96(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 96; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_fwd_tree, Is_dropout>(params, stream); + }); +} + +template +void run_mha_fwd_tree_hdim128(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_fwd_tree, Is_dropout>(params, stream); + }); +} + +template +void run_mha_fwd_tree_hdim160(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 160; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_fwd_tree, Is_dropout>(params, stream); + }); +} + +template +void run_mha_fwd_tree_hdim192(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 192; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_fwd_tree, Is_dropout>(params, stream); + }); +} + +template +void run_mha_fwd_tree_hdim224(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 224; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_fwd_tree, Is_dropout>(params, stream); + }); +} + +template +void run_mha_fwd_tree_hdim256(Flash_fwd_params_tree ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 256; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_fwd_tree, Is_dropout>(params, stream); + }); +} +} // FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_split_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_split_hdim128_bf16_sm80.cu new file mode 100644 index 0000000000..291d54e5a2 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_split_hdim128_bf16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_split_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_split_hdim128_fp16_sm80.cu new file mode 100644 index 0000000000..e5958f3d84 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_split_hdim128_fp16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_split_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_split_hdim160_bf16_sm80.cu new file mode 100644 index 0000000000..7d08cfd314 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_split_hdim160_bf16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_split_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_split_hdim160_fp16_sm80.cu new file mode 100644 index 0000000000..65a02acba4 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_split_hdim160_fp16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_split_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_split_hdim192_bf16_sm80.cu new file mode 100644 index 0000000000..7dd14a522f --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_split_hdim192_bf16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_split_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_split_hdim192_fp16_sm80.cu new file mode 100644 index 0000000000..5cb71212da --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_split_hdim192_fp16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_split_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_split_hdim256_bf16_sm80.cu new file mode 100644 index 0000000000..b9db7cd6a0 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_split_hdim256_bf16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_split_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_split_hdim256_fp16_sm80.cu new file mode 100644 index 0000000000..7397783025 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_split_hdim256_fp16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_split_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_split_hdim32_bf16_sm80.cu new file mode 100644 index 0000000000..c9a3013ebf --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_split_hdim32_bf16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_split_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_split_hdim32_fp16_sm80.cu new file mode 100644 index 0000000000..a48acbd520 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_split_hdim32_fp16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_split_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_split_hdim64_bf16_sm80.cu new file mode 100644 index 0000000000..b03d48225e --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_split_hdim64_bf16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. 
+// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_split_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_split_hdim64_fp16_sm80.cu new file mode 100644 index 0000000000..43809b25f4 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_split_hdim64_fp16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_split_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_split_hdim96_bf16_sm80.cu new file mode 100644 index 0000000000..fc7a4a0010 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_split_hdim96_bf16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_tree_split_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_tree_split_hdim96_fp16_sm80.cu new file mode 100644 index 0000000000..92674bf928 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_tree_split_hdim96_fp16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_tree.h b/csrc/flash_attn/src/flash_tree.h new file mode 100644 index 0000000000..53a005b867 --- /dev/null +++ b/csrc/flash_attn/src/flash_tree.h @@ -0,0 +1,21 @@ +/****************************************************************************** + * Copyright (c) 2025, Tri Dao, Samsung SDSA. + ******************************************************************************/ + +#pragma once + +#include "flash.h" + +namespace FLASH_NAMESPACE { + +struct Flash_fwd_params_tree : public Flash_fwd_params { + uint64_t * __restrict__ tree_mask_ptr; // For dynamically masking in tree attention. + int * __restrict__ tree_mask_lens_ptr; // The length of each of the speculated sequences (batch_size + 1). 
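+    // Layout sketch (inferred from the bit test in mask_tree.h, not a normative spec): each speculative
+    // query row owns one 64-bit mask word, and bit (seqlen_k - 1 - col) being set means that row may
+    // attend to key column col; a cleared bit maps the score to -inf. tree_mask_lens_ptr, sized
+    // batch_size + 1, appears to be a cumulative prefix used to locate each sequence's slice of words.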
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void run_mha_fwd_tree_(Flash_fwd_params_tree ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/generate_kernels.py b/csrc/flash_attn/src/generate_kernels.py index 755ee8fea5..50aeb44abb 100644 --- a/csrc/flash_attn/src/generate_kernels.py +++ b/csrc/flash_attn/src/generate_kernels.py @@ -59,6 +59,26 @@ def get_fwd_sparse_template() -> str: }} // namespace FLASH_NAMESPACE""" +def get_fwd_tree_template() -> str: + return NAMESPACE_INCLUDE + """#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE {{ + +template<> +void run_mha_fwd_tree_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params_tree ¶ms, cudaStream_t stream) {{ + run_mha_fwd_tree_hdim{HEAD_DIM}<{DTYPE}>(params, stream); +}} + +}} // namespace FLASH_NAMESPACE""" + +def get_fwd_split_tree_template() -> str: + return NAMESPACE_INCLUDE + """#include "flash_fwd_tree_launch_template.h" + +namespace FLASH_NAMESPACE {{ + +template void run_mha_fwd_splitkv_dispatch_tree<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params_tree ¶ms, cudaStream_t stream); + +}} // namespace FLASH_NAMESPACE""" @dataclass class Kernel: @@ -75,6 +95,8 @@ def template(self) -> str: "bwd": get_bwd_template, "fwd_split": get_fwd_split_template, "fwd_sparse": get_fwd_sparse_template, + "fwd_tree": get_fwd_tree_template, + "fwd_tree_split": get_fwd_split_tree_template, } template_func = template_funcs[self.direction] return template_func().format( @@ -94,6 +116,9 @@ def get_all_kernels() -> List[Kernel]: # For sparse only generate HEAD_DIM=128 for now since this the only one we use currently for dtype, is_causal, sm in itertools.product(DTYPE_MAP.keys(), IS_CAUSAL, SM): yield Kernel(sm=sm, dtype=dtype, head_dim=128, is_causal=is_causal, direction="fwd_sparse") + for direction in ["fwd_tree", "fwd_tree_split"]: + for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM): + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=True, direction=direction) def write_kernel(kernel: Kernel, autogen_dir: Path) -> None: prelude = """// Copyright (c) 2024, Tri Dao. diff --git a/csrc/flash_attn/src/mask_tree.h b/csrc/flash_attn/src/mask_tree.h new file mode 100644 index 0000000000..3245342ee6 --- /dev/null +++ b/csrc/flash_attn/src/mask_tree.h @@ -0,0 +1,110 @@ +/****************************************************************************** + * Copyright (c) 2025, Tri Dao, Samsung SDSA. + ******************************************************************************/ + +#pragma once +#include "namespace_config.h" + +#include + +namespace FLASH_NAMESPACE { + +using namespace cute; + + +template +struct TreeMask { + + const int max_seqlen_k, max_seqlen_q; + const float alibi_slope; + + __forceinline__ __device__ TreeMask(const int max_seqlen_k, const int max_seqlen_q, + const float alibi_slope=0.f) + : max_seqlen_k(max_seqlen_k) + , max_seqlen_q(max_seqlen_q) + , alibi_slope(!Has_alibi ? 
0.0 : alibi_slope) { + }; + + // Causal_mask: whether this particular iteration needs causal masking + template + __forceinline__ __device__ void apply_mask(Tensor &tensor_, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride, + const Tensor &mask) { + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); + static constexpr bool Need_masking = Has_alibi || Causal_mask || !Is_even_MN; + // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_even_MN, Need_masking); } + if constexpr (Need_masking) { + // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout())); + // Do we need both row and column indices, or just column incides? + // if (thread0()) { + // print(tensor_.layout()); print("\n"); + // print(tensor.layout()); print("\n"); + // } + static constexpr bool Col_idx_only = !Causal_mask; + const int col_idx_offset = col_idx_offset_ + (threadIdx.x % 4) * 2; + if constexpr (Col_idx_only) { + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // No causal, no local + if constexpr (Has_alibi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + if constexpr (!Is_even_MN) { + if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; } + } + } + } + } + } else { + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if constexpr (Has_alibi) { + tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; + } + if constexpr (Causal_mask) { + if (col_idx >= col_idx_limit_right) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + if constexpr (Tree_mask) { + if (row_idx >= (max_seqlen_q - size<0>(mask)) && row_idx < max_seqlen_q && col_idx >= (max_seqlen_k - size<0>(mask)) && col_idx < col_idx_limit_right) { + tensor(make_coord(i, mi), make_coord(j, nj)) += mask(make_coord(row_idx + size<0>(mask)-max_seqlen_q)) & (1ULL << (max_seqlen_k-1-col_idx)) ? 
0 : -INFINITY; + } + } + } + if constexpr (!Causal_mask && !Is_even_MN) { + // Causal and Local already handles MN masking + if (col_idx >= max_seqlen_k) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } + } + } + }; +}; + +} // FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/tree_attention.cpp b/csrc/flash_attn/tree_attention.cpp new file mode 100644 index 0000000000..28dcf4e2c0 --- /dev/null +++ b/csrc/flash_attn/tree_attention.cpp @@ -0,0 +1,423 @@ +/****************************************************************************** + * Copyright (c) 2025, Tri Dao, Samsung SDSA. + ******************************************************************************/ + +#include + +#include +#include +#include +#include // For at::Generator and at::PhiloxCudaState +#include "philox_unpack.cuh" // For at::cuda::philox::unpack + +#include + +#include "namespace_config.h" +#include "hardware_info.h" +#include "flash_tree.h" +#include "static_switch.h" + + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +namespace FLASH_NAMESPACE { + + +// +// Bit hacky but for now hook into the existing set_params_fprop, +// set_params_splitkv, and set_params_alibi in flash_api.cpp +// +void set_params_fprop(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + at::Tensor out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_k, + void *p_d, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + const float softcap, + bool seqlenq_ngroups_swapped=false, + const bool unpadded_lse=false); + +std::tuple +set_params_splitkv( + Flash_fwd_params ¶ms, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_k, + const int max_seqlen_q, + const int head_size_rounded, + const float p_dropout, + const int num_splits, + const int num_sm, + struct c10::TensorOptions opts); + +void set_params_alibi( + Flash_fwd_params ¶ms, + std::optional &alibi_slopes_, + int batch_size, + int num_heads); + + +void set_params_fprop_tree( + Flash_fwd_params_tree ¶ms, + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + at::Tensor out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_k, + void *p_d, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + const float softcap, + void *tree_mask, + void *tree_mask_lens, + bool seqlenq_ngroups_swapped=false, + const bool unpadded_lse=false +) +{ + set_params_fprop( + params, + b, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + seqlen_k_rounded, + h, + h_k, + d, + d_rounded, + q, + k, + v, + out, + cu_seqlens_q_d, + cu_seqlens_k_d, + seqused_k, + p_d, + softmax_lse_d, + p_dropout, + softmax_scale, + 
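For readers of TreeMask::apply_mask above: the rule it implements is ordinary bottom-right-aligned causal masking, with the last len(tree_rows) rows and columns additionally gated by the per-row uint64 bitmask, where bit (max_seqlen_k - 1 - col) must be set for a score to survive. A minimal PyTorch reference of that rule (my own helper name, a readability sketch rather than the kernel; the alibi and uneven-MN paths are omitted):

import torch

def reference_tree_mask(max_seqlen_q: int, max_seqlen_k: int, tree_rows: list[int]) -> torch.Tensor:
    """Additive (0 / -inf) mask of shape (max_seqlen_q, max_seqlen_k).
    Assumes len(tree_rows) <= min(max_seqlen_q, max_seqlen_k)."""
    n = len(tree_rows)
    mask = torch.zeros(max_seqlen_q, max_seqlen_k)
    for row in range(max_seqlen_q):
        # Bottom-right-aligned causal limit, as computed in the kernel.
        limit = min(max_seqlen_k, row + 1 + max_seqlen_k - max_seqlen_q)
        mask[row, limit:] = float("-inf")
        # Tree region: the last n rows, columns in [max_seqlen_k - n, limit).
        if row >= max_seqlen_q - n:
            bits = tree_rows[row - (max_seqlen_q - n)]
            for col in range(max_seqlen_k - n, limit):
                if not ((bits >> (max_seqlen_k - 1 - col)) & 1):
                    mask[row, col] = float("-inf")
    return mask

m = reference_tree_mask(9, 9, [0b100000, 0b010000, 0b101000, 0b010100, 0b101010, 0b010101])
# m[4, 3] is -inf (branch token "1" cannot see "d"), while m[5, 3] is 0 ("e" can see "d").

This is only a readability aid; the kernel applies the same rule tile-by-tile in registers.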
window_size_left, + window_size_right, + softcap, + seqlenq_ngroups_swapped, + unpadded_lse + ); + params.tree_mask_ptr = static_cast(tree_mask); + params.tree_mask_lens_ptr = static_cast(tree_mask_lens); +} + +void run_mha_fwd_tree(Flash_fwd_params_tree ¶ms, cudaStream_t stream, bool force_split_kernel=false) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, [&] { + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_tree_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch_tree(params, stream); + } + }); + }); +} + + +std::vector +tree_attention( + at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + std::optional &leftpad_k_, // batch_size + std::optional &block_table_, // batch_size x max_num_blocks_per_seq + std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + const float softcap, + const bool return_softmax, + std::optional gen_, + const at::Tensor &tree_mask, + const at::Tensor &tree_mask_lens) +{ + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x_min = cc_major >= 8; + TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(tree_mask.dtype() == torch::kUInt64, "TreeAttention only support uint64 data type for tree_mask"); + TORCH_CHECK(tree_mask_lens.dtype() == torch::kInt32, "TreeAttention only support i32 data type for tree_mask_lens"); + CHECK_DEVICE(tree_mask); CHECK_DEVICE(tree_mask_lens); + CHECK_CONTIGUOUS(tree_mask); CHECK_CONTIGUOUS(tree_mask_lens); + + at::Tensor block_table; + const bool paged_KV = block_table_.has_value(); + if (paged_KV) { + block_table = block_table_.value(); + CHECK_DEVICE(block_table); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + } + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, 
"Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size = sizes[2]; + const int num_heads_k = paged_KV ? k.size(2) : k.size(1); + + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); + const int num_blocks = !paged_KV ? 0 : k.size(0); + const int page_block_size = !paged_KV ? 1 : k.size(1); + TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16"); + + void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value(); + const int ngroups = num_heads / num_heads_k; + if (seqlenq_ngroups_swapped) { + q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); + max_seqlen_q = ngroups; + num_heads = num_heads_k; + cu_seqlens_q_d = nullptr; + } + + const int total_q = q.sizes()[0]; + + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + + CHECK_SHAPE(q, total_q, num_heads, head_size); + if (!paged_KV) { + const int total_k = k.size(0); + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size); + } else { + CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + } + + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + CHECK_SHAPE(tree_mask_lens, batch_size + 1); + if (seqused_k.has_value()){ + auto seqused_k_ = seqused_k.value(); + TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); + TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); + CHECK_SHAPE(seqused_k_, batch_size); + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, sizes[0], sizes[1], head_size); + if (seqlenq_ngroups_swapped) { + // NOTE(woosuk): We create a temporary buffer and copy the result to the `out_` tensor eventually. + // This is because we reshaped the `q` tensor for the splik-KV optimization, and the `out_` tensor + // has the same shape as the original `q` tensor, not the reshaped one. 
+ out = torch::empty_like(q); + } + } else { + out = torch::empty_like(q); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + auto opts = q.options(); + auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + at::Tensor p; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + } + else { + p = torch::empty({ 0 }, opts); + } + + if (zero_tensors) { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_softmax) {p.zero_();} + } + + Flash_fwd_params_tree params; + set_params_fprop_tree( + params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + cu_seqlens_q_d, + cu_seqlens_k.data_ptr(), + seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, + return_softmax ? p.data_ptr() : nullptr, + softmax_lse.data_ptr(), + p_dropout, + softmax_scale, + -1, + 0, + softcap, + tree_mask.data_ptr(), + tree_mask_lens.data_ptr(), + seqlenq_ngroups_swapped, + /*unpadded_lse*/true + ); + params.total_q = total_q; + + if (paged_KV) { + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + } + params.page_block_size = page_block_size; + // Keep references to these tensors to extend their lifetime + at::Tensor softmax_lse_accum, out_accum; + if (seqlenq_ngroups_swapped) { + // Only apply split-k for decoding + std::tie(softmax_lse_accum, out_accum) = + set_params_splitkv(params, batch_size, num_heads, head_size, + max_seqlen_k, max_seqlen_q, head_size_rounded, + p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts); + } + + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); + CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + params.leftpad_k = static_cast(leftpad_k.data_ptr()); + } + + // NOTE(woosuk): Commented out because they are not used in inference. + // number of times random will be generated per thread, to offset philox counter in thc random + // state + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + // int64_t counter_offset = params.b * params.h * 32; + // auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + // auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + // // Forward kernel will populate memory with the seed and offset. 
+ // params.rng_state = reinterpret_cast(rng_state.data_ptr()); + + // if (p_dropout > 0.0) { + // auto gen = at::get_generator_or_default( + // gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // // See Note [Acquire lock when using random generators] + // std::lock_guard lock(gen->mutex_); + // params.philox_args = gen->philox_cuda_state(counter_offset); + // } + + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + if (max_seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd_tree(params, stream, paged_KV); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + if (seqlenq_ngroups_swapped) { + int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size}; + int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size}; + out = out.reshape(size_before).transpose(1, 2); + if (out_.has_value()) { + // NOTE(woosuk): In this case, we should avoid `out.reshape(size_after)` because it causes + // a redundant clone operation. Instead, we directly copy the result to the `out_` tensor. + out_.value().view({batch_size, num_heads_k, max_seqlen_q, head_size}).copy_(out); + out = out_.value(); + } else { + out = out.reshape(size_after); + } + // NOTE(woosuk): The two lines are not needed because out_padded and q_padded are not used. + // out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after); + // q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after); + int64_t lse_size_before[] = {num_heads, batch_size, max_seqlen_q}; + int64_t lse_size_after[] = {num_heads * max_seqlen_q, batch_size}; + softmax_lse = softmax_lse.reshape(lse_size_before).transpose(1, 2).reshape(lse_size_after); + } + + return {out, softmax_lse}; +} + +} // FLASH_NAMESPACE \ No newline at end of file diff --git a/tests/test_vllm_flash_attn.py b/tests/test_vllm_flash_attn.py index ddc2ea0383..bb06135d13 100644 --- a/tests/test_vllm_flash_attn.py +++ b/tests/test_vllm_flash_attn.py @@ -4,6 +4,7 @@ # import math +import random from typing import List, Optional, Tuple import pytest @@ -15,6 +16,13 @@ flash_attn_with_kvcache, get_scheduler_metadata, is_fa_version_supported, + tree_attention, +) + +from vllm_flash_attn.utils.tree import ( + create_tree_mask, + generate_q_and_block_kvcache, + treeify_output, ) NUM_HEADS = [(4, 4), (8, 2), (16, 2)] @@ -522,3 +530,234 @@ def test_sparse_attention_varlen( f"{torch.max(torch.abs(out - ref_out))}" torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \ f"{torch.max(torch.abs(lse - ref_lse))}" + + +@pytest.mark.parametrize("num_seqs", [1]) +@pytest.mark.parametrize("seq_len", [300]) +@pytest.mark.parametrize("num_heads", [16]) +def test_tree_masking_no_block_table(num_seqs, seq_len, num_heads): + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + head_size = 128 + dtype=torch.float16 + query = torch.randn(seq_len, num_heads, head_size, dtype=dtype) + key = torch.randn(seq_len, num_heads, head_size, dtype=dtype) + value = torch.randn_like(key) + cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32) + cu_seqlens_k = torch.tensor([0, seq_len], dtype=torch.int32) + mask = torch.tensor([ + 0b100000, + 0b010000, + 0b001000, + 0b100100, + 0b001010, + 0b100001, + ], dtype=torch.uint64) + mask_lens = torch.tensor([0, 6], dtype=torch.int32) + scale = head_size**-0.5 + soft_cap = None + + output, lse = tree_attention( + q=query, + 
k=key, + v=value, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seq_len, + max_seqlen_k=seq_len, + tree_mask=mask, + tree_mask_lens=mask_lens, + softmax_scale=scale, + softcap=soft_cap if soft_cap is not None else 0, + fa_version=2, + return_softmax_lse=True, + ) + + q_s1 = torch.cat((query[:seq_len-5], query[seq_len-3:seq_len-2])) + k_s1 = torch.cat((key[:seq_len-5], key[seq_len-3:seq_len-2])) + v_s1 = torch.cat((value[:seq_len-5], value[seq_len-3:seq_len-2])) + cu_seqlens_q_s1 = torch.tensor([0, seq_len-4], dtype=torch.int32) + cu_seqlens_k_s1 = torch.tensor([0, seq_len-4], dtype=torch.int32) + + ref_output_s1, ref_lse = flash_attn_varlen_func( + q=q_s1, + k=k_s1, + v=v_s1, + cu_seqlens_q=cu_seqlens_q_s1, + cu_seqlens_k=cu_seqlens_k_s1, + max_seqlen_q=seq_len, + max_seqlen_k=seq_len, + softmax_scale=scale, + window_size=(-1, -1), + softcap=soft_cap if soft_cap is not None else 0, + fa_version=2, + return_softmax_lse=True, + causal=True, + ) + + output_s1 = torch.cat((output[:seq_len-5], output[seq_len-3:seq_len-2])) + torch.testing.assert_close(output_s1, ref_output_s1, atol=2e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output_s1 - ref_output_s1))}" + + q_s2 = torch.cat((query[:seq_len-5], query[seq_len-1:])) + k_s2 = torch.cat((key[:seq_len-5], key[seq_len-1:])) + v_s2 = torch.cat((value[:seq_len-5], value[seq_len-1:])) + cu_seqlens_q_s2 = torch.tensor([0, seq_len-4], dtype=torch.int32) + cu_seqlens_k_s2 = torch.tensor([0, seq_len-4], dtype=torch.int32) + + ref_output_s2, ref_lse = flash_attn_varlen_func( + q=q_s2, + k=k_s2, + v=v_s2, + cu_seqlens_q=cu_seqlens_q_s2, + cu_seqlens_k=cu_seqlens_k_s2, + max_seqlen_q=seq_len, + max_seqlen_k=seq_len, + softmax_scale=scale, + window_size=(-1, -1), + softcap=soft_cap if soft_cap is not None else 0, + fa_version=2, + return_softmax_lse=True, + causal=True, + ) + + output_s2 = torch.cat((output[:seq_len-5], output[seq_len-1:])) + torch.testing.assert_close(output_s2, ref_output_s2, atol=2e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output_s2 - ref_output_s2))}" + + q_s3 = torch.cat((query[:seq_len-6], query[seq_len-5:seq_len-4])) + k_s3 = torch.cat((key[:seq_len-6], key[seq_len-5:seq_len-4])) + v_s3 = torch.cat((value[:seq_len-6], value[seq_len-5:seq_len-4])) + cu_seqlens_q_s3 = torch.tensor([0, seq_len-5], dtype=torch.int32) + cu_seqlens_k_s3 = torch.tensor([0, seq_len-5], dtype=torch.int32) + + ref_output_s3, ref_lse = flash_attn_varlen_func( + q=q_s3, + k=k_s3, + v=v_s3, + cu_seqlens_q=cu_seqlens_q_s3, + cu_seqlens_k=cu_seqlens_k_s3, + max_seqlen_q=seq_len, + max_seqlen_k=seq_len, + softmax_scale=scale, + window_size=(-1, -1), + softcap=soft_cap if soft_cap is not None else 0, + fa_version=2, + return_softmax_lse=True, + causal=True, + ) + + output_s3 = torch.cat((output[:seq_len-6], output[seq_len-5:seq_len-4])) + torch.testing.assert_close(output_s3, ref_output_s3, atol=2e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output_s3 - ref_output_s3))}" + + + q_s4 = torch.cat((query[:seq_len-6], query[seq_len-4:seq_len-3], query[seq_len-2:seq_len-1])) + k_s4 = torch.cat((key[:seq_len-6], key[seq_len-4:seq_len-3], key[seq_len-2:seq_len-1])) + v_s4 = torch.cat((value[:seq_len-6], value[seq_len-4:seq_len-3], value[seq_len-2:seq_len-1])) + cu_seqlens_q_s4 = torch.tensor([0, seq_len-4], dtype=torch.int32) + cu_seqlens_k_s4 = torch.tensor([0, seq_len-4], dtype=torch.int32) + + ref_output_s4, ref_lse = flash_attn_varlen_func( + q=q_s4, + k=k_s4, + v=v_s4, + cu_seqlens_q=cu_seqlens_q_s4, + cu_seqlens_k=cu_seqlens_k_s4, + 
max_seqlen_q=seq_len, + max_seqlen_k=seq_len, + softmax_scale=scale, + window_size=(-1, -1), + softcap=soft_cap if soft_cap is not None else 0, + fa_version=2, + return_softmax_lse=True, + causal=True, + ) + + output_s4 = torch.cat((output[:seq_len-6], output[seq_len-4:seq_len-3], output[seq_len-2:seq_len-1])) + torch.testing.assert_close(output_s4, ref_output_s4, atol=2e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output_s4 - ref_output_s4))}" + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize( +# "seqlen_q,seqlen_k", +# [ +# (1, 239), +# (3, 799), +# (127, 512), +# (127, 513), +# (113, 203), +# (128, 217), +# (113, 211), +# (108, 256), +# (256, 512), +# (1023, 1024), +# ], +# ) +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(7,239), (1023, 1024)]) +@pytest.mark.parametrize("swap_sq_sk", [False, True]) +# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged +@pytest.mark.parametrize("paged_kv_block_size", [16, 256]) +def test_paged_tree_attention(seqlen_q, seqlen_k, swap_sq_sk, d, paged_kv_block_size, dtype): + device = "cuda" + # set seed + torch.random.manual_seed(0) + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + batch_size = 8 + nheads = 9 + q_seqlens = [seqlen_q+random.randint(0, 20) for _ in range(batch_size)] + k_seqlens = [seqlen_k+random.randint(0, 20) for _ in range(batch_size)] + speclens = [(random.randint(1, 8), random.randint(2, 3)) for _ in range(batch_size)] + ( + q_spec_tree, + q_seqlens_tree, + q_spec_batch, + q_seqlens_batch, + tree_block_table, + k_spec_tree, + v_spec_tree, + k_seqlens_tree, + batch_block_table, + k_spec_batch, + v_spec_batch, + k_seqlens_batch, + ) = generate_q_and_block_kvcache(q_seqlens, k_seqlens, speclens, paged_kv_block_size, nheads, d, device, dtype) + tree_mask = create_tree_mask(speclens, device) + tree_mask_lens = torch.tensor([0] + [i*j for i,j in speclens], dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) + cu_seqlens_q_tree = torch.tensor([0] + q_seqlens_tree, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) + seqused_k_tree = torch.tensor(k_seqlens_tree, dtype=torch.int32) + cu_seqlens_q_batch = torch.tensor([0] + q_seqlens_batch, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) + seqused_k_batch = torch.tensor(k_seqlens_batch, dtype=torch.int32) + + out = tree_attention( + q_spec_tree, + k_spec_tree, + v_spec_tree, + max(q_seqlens_tree), + cu_seqlens_q_tree, + max(k_seqlens_tree), + tree_mask, + tree_mask_lens, + seqused_k=seqused_k_tree, + block_table=tree_block_table, + ) + + ref_output = flash_attn_varlen_func( + q_spec_batch, + k_spec_batch, + v_spec_batch, + max(q_seqlens_batch), + cu_seqlens_q_batch, + max(k_seqlens_batch), + seqused_k=seqused_k_batch, + causal=True, + block_table=batch_block_table, + ) + + ref_output_tree = treeify_output(ref_output, q_seqlens, speclens) + torch.testing.assert_close(out, ref_output_tree, atol=2e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(out - ref_output_tree))}" + diff --git a/vllm_flash_attn/__init__.py b/vllm_flash_attn/__init__.py index 4a95ac70ac..31e460ca49 100644 --- a/vllm_flash_attn/__init__.py +++ b/vllm_flash_attn/__init__.py @@ -8,5 +8,6 @@ sparse_attn_func, sparse_attn_varlen_func, is_fa_version_supported, - fa_version_unsupported_reason + fa_version_unsupported_reason, + tree_attention ) \ No newline at end of file diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py 
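To make the fixture in test_tree_masking_no_block_table above easier to follow, the six uint64 rows can be decoded into the columns each speculative token attends to (within the bottom-right n x n block, column c is attended if bit (n - 1 - c) of the row is set). A throwaway decoding helper, purely illustrative:

def mask_ancestors(rows: list[int]) -> dict[int, list[int]]:
    """Columns attended by each row of a bottom-right-aligned tree mask."""
    n = len(rows)
    return {r: [c for c in range(n) if (rows[r] >> (n - 1 - c)) & 1] for r in range(n)}

print(mask_ancestors([0b100000, 0b010000, 0b001000, 0b100100, 0b001010, 0b100001]))
# {0: [0], 1: [1], 2: [2], 3: [0, 3], 4: [2, 4], 5: [0, 5]}

So the six speculative tokens form three roots (0, 1, 2), with tokens 3 and 5 branching off root 0 and token 4 branching off root 2. The four linear sequences the test slices out (s1 through s4) are exactly the root-to-leaf paths 0-3, 0-5, 1, and 2-4, each checked against causal flash_attn_varlen_func.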
index 06de7fd17b..9a9537bab6 100644 --- a/vllm_flash_attn/flash_attn_interface.py +++ b/vllm_flash_attn/flash_attn_interface.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Tri Dao. +# Copyright (c) 2023, Tri Dao, Samsung SDSA. from typing import Optional, Union, Tuple, List @@ -612,3 +612,170 @@ def sparse_attn_varlen_func( None, ) return (out, softmax_lse) if return_softmax_lse else out + + +def tree_attention( + q, + k, + v, + max_seqlen_q, + cu_seqlens_q, + max_seqlen_k, + tree_mask, + tree_mask_lens, + cu_seqlens_k=None, # only used for non-paged prefill + seqused_k=None, + q_v=None, + dropout_p=0.0, + softmax_scale=None, + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + block_table=None, + return_softmax_lse=False, + out=None, + # FA3 Only + scheduler_metadata=None, + q_descale=None, + k_descale=None, + v_descale=None, + # Version selector + fa_version: int = DEFAULT_FA_VERSION, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + The tree_mask is aligned to the bottom right corner of the attention matrix. It overrides the + causal mask in those positions. Each int64 represents a row in the mask. Allowing up to 64 + masked tokens (not including causal mask). A tree_mask for a single sequence might look like + this. + + seq = [a,b,c,d,1,e,2,f,3] + tree_mask = [0b100000, + 0b010000, + 0b101000, + 0b010100, + 0b101010, + 0b010101] + tree_mask_len = [0,6] + + This corresponds to attention on the sequences abcdef, abc123. + + tree_mask_lens is similiar to the cu_seqlens arguments. Since the lengths are variable the + tree_mask for a batch is concatenated into one dimension. The tree_mask_lens contain the start + and end offset of the trees. For example: + + tree 0 = tree_mask[tree_mask_len[0]:tree_mask_len[0+1]] + tree 1 = tree_mask[tree_mask_len[1]:tree_mask_len[1+1]] + tree 2 = tree_mask[tree_mask_len[2]:tree_mask_len[2+1]] + or more generally + tree i = tree_mask[tree_mask_len[i]:tree_mask_len[i+1]] + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + tree_mask: (total_tree), where total_tree = total number of mask. dtype torch.int64. + tree_mask_lens: (batch_size + 1,), dtype torch.int32. The cumulative mask lengths + of the masks in the batch, used to index into tree_mask. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. 
A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + """ + assert cu_seqlens_k is not None or seqused_k is not None, \ + "cu_seqlens_k or seqused_k must be provided" + assert cu_seqlens_k is None or seqused_k is None, \ + "cu_seqlens_k and seqused_k cannot be provided at the same time" + assert block_table is None or seqused_k is not None, \ + "seqused_k must be provided if block_table is provided" + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + + dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q) + + if fa_version == 2: + if scheduler_metadata is not None and q_descale is not None \ + and k_descale is not None and v_descale is not None: + raise NotImplementedError( + "FA2 does not support scheduler_metadata, q_descale, " + "k_descale, v_descale" + ) + out, softmax_lse = torch.ops._vllm_fa2_C.tree_attention( + q, k, v, + out, + cu_seqlens_q, + # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp + # still wants it so we pass all zeros + dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k, + seqused_k, + None, + block_table, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + softcap, + return_softmax_lse and dropout_p > 0, + None, + tree_mask, + tree_mask_lens, + ) + # elif fa_version == 3: + # assert alibi_slopes is None, "Alibi is not supported in FA3" + # out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd( + # q, k, v, + # None, None, # k_new, v_new + # q_v, + # out, + # cu_seqlens_q, + # cu_seqlens_k, # cu_seqlens_k + # None, # cu_seqlens_k_new + # None, seqused_k, # seqused_q, seqused_k + # max_seqlen_q, max_seqlen_k, + # block_table, + # None, # kv_batch_idx + # None, # leftpad_k + # None, None, None, # rotary_cos, rotary_sin, seqlens_rotary + # q_descale, k_descale, v_descale, + # softmax_scale, + # True, # causal + # real_window_size[0], real_window_size[1], + # softcap, + # True, # rotary_interleaved + # scheduler_metadata, + # 0, # num_splits + # None, # pack_gqa + # 0, # sm_margin + # ) + else: + raise ValueError(f"Unsupported FA version: {fa_version}") + return (out, softmax_lse) if return_softmax_lse else out diff --git a/vllm_flash_attn/utils/__init__.py b/vllm_flash_attn/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_flash_attn/utils/benchmark.py b/vllm_flash_attn/utils/benchmark.py new file mode 100644 index 0000000000..15b30405f2 --- /dev/null +++ b/vllm_flash_attn/utils/benchmark.py @@ -0,0 +1,268 @@ +# Copyright (c) 2023, Tri Dao. +""" Useful functions for writing test code. 
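A minimal end-to-end sketch of calling the tree_attention wrapper above, using the two-branch example from its docstring (a shared prefix followed by the interleaved speculative tokens d,1,e,2,f,3). This is illustrative only: it assumes a CUDA device and a build that includes this patch, and mirrors the keyword style used by the tests.

import torch
from vllm_flash_attn import tree_attention

torch.set_default_device("cuda")
prefix_len, n_spec = 128, 6
seq_len = prefix_len + n_spec
nheads, head_size = 4, 128
dtype = torch.float16

q = torch.randn(seq_len, nheads, head_size, dtype=dtype)
k = torch.randn(seq_len, nheads, head_size, dtype=dtype)
v = torch.randn_like(k)

# Rows for the two branches prefix+d,e,f and prefix+1,2,3 (see the docstring above).
tree_mask = torch.tensor([0b100000, 0b010000, 0b101000,
                          0b010100, 0b101010, 0b010101], dtype=torch.uint64)
tree_mask_lens = torch.tensor([0, 6], dtype=torch.int32)
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32)

out, lse = tree_attention(
    q=q, k=k, v=v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=seq_len, max_seqlen_k=seq_len,
    tree_mask=tree_mask, tree_mask_lens=tree_mask_lens,
    softmax_scale=head_size ** -0.5,
    fa_version=2,
    return_softmax_lse=True,
)
# out: (seq_len, nheads, head_size); lse: (nheads, seq_len)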
""" + +import torch +import torch.utils.benchmark as benchmark + + +def benchmark_forward( + fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs +): + """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" + if verbose: + print(desc, "- Forward pass") + + def amp_wrapper(*inputs, **kwinputs): + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + fn(*inputs, **kwinputs) + + t = benchmark.Timer( + stmt="fn_amp(*inputs, **kwinputs)", + globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_backward( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the backward pass of an arbitrary function.""" + if verbose: + print(desc, "- Backward pass") + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError("Grad shape does not match output shape") + + def f(*inputs, y, grad): + # Set .grad to None to avoid extra operation of gradient accumulation + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + y.backward(grad, retain_graph=True) + + t = benchmark.Timer( + stmt="f(*inputs, y=y, grad=grad)", + globals={"f": f, "inputs": inputs, "y": y, "grad": grad}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_combined( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + if verbose: + print(desc, "- Forward + Backward pass") + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError("Grad shape does not match output shape") + + def f(grad, *inputs, **kwinputs): + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + y.backward(grad, retain_graph=True) + + t = benchmark.Timer( + stmt="f(grad, *inputs, **kwinputs)", + globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_fwd_bwd( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + return ( + benchmark_forward( + fn, + *inputs, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_backward( + fn, + *inputs, + grad=grad, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + ) + + +def benchmark_all( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch 
Benchmark on the forward+backward pass of an arbitrary function.""" + return ( + benchmark_forward( + fn, + *inputs, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_backward( + fn, + *inputs, + grad=grad, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_combined( + fn, + *inputs, + grad=grad, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + ) + + +def pytorch_profiler( + fn, + *inputs, + trace_filename=None, + backward=False, + amp=False, + amp_dtype=torch.float16, + cpu=False, + verbose=True, + **kwinputs, +): + """Wrap benchmark functions in Pytorch profiler to see CUDA information.""" + if backward: + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + g = torch.randn_like(out) + for _ in range(30): # Warm up + if backward: + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + # Backward should be done outside autocast + if backward: + out.backward(g, retain_graph=True) + activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [ + torch.profiler.ProfilerActivity.CUDA + ] + with torch.profiler.profile( + activities=activities, + record_shapes=True, + # profile_memory=True, + with_stack=True, + ) as prof: + if backward: + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + if backward: + out.backward(g, retain_graph=True) + if verbose: + # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)) + print(prof.key_averages().table(row_limit=50)) + if trace_filename is not None: + prof.export_chrome_trace(trace_filename) + + +def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + fn(*inputs, **kwinputs) + torch.cuda.synchronize() + mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000) + if verbose: + print(f"{desc} max memory: {mem}GB") + torch.cuda.empty_cache() + return mem diff --git a/vllm_flash_attn/utils/tree.py b/vllm_flash_attn/utils/tree.py new file mode 100644 index 0000000000..008804f1d1 --- /dev/null +++ b/vllm_flash_attn/utils/tree.py @@ -0,0 +1,190 @@ +# Copyright (c) 2025, Samsung SDSA. 
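The benchmark helpers above are thin wrappers around torch.utils.benchmark.Timer; a tiny usage sketch (any callable works, shown here with a plain matmul rather than the attention kernels):

import torch
from vllm_flash_attn.utils.benchmark import benchmark_forward

a = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)
b = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)

# Returns the Timer and its Measurement; .mean is the average seconds per call.
timer, measurement = benchmark_forward(torch.matmul, a, b, repeats=30, desc="matmul", verbose=False)
print(f"matmul mean: {measurement.mean:.6f} s")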
+ +import math +import torch +from einops import rearrange + + +def create_mask(tree: list[int]): + out = [] + for node in tree: + out = [o << 1 for o in out] + mask = 1 + if node != -1: + mask |= out[node] + out.append(mask) + return out + + +def simple_mask(spec_len, num_seq): + mask = [] + for seq in range(num_seq): + pos = -1 + for s in range(spec_len): + mask.append(pos) + pos = len(mask)-1 + return mask + + +def create_tree_mask(speclens, device): + outputs: list[torch.Tensor] = [] + for spec_len, spec_branchs in speclens: + outputs.append(torch.tensor(create_mask(simple_mask(spec_len, spec_branchs)), dtype=torch.uint64, device=device)) + return torch.cat(outputs, dim=0) + + +def create_spec_tree(input_base: torch.Tensor, spec_base: torch.Tensor, seqlens: list[int], speclens: list[tuple[int, int]]): + seq_idx = 0 + spec_idx = 0 + outputs = [] + for seq_len, (spec_len, spec_branchs) in zip(seqlens, speclens): + total_spec_len = spec_len * spec_branchs + outputs += [input_base[seq_idx:seq_idx+seq_len], spec_base[spec_idx:spec_idx+total_spec_len]] + seq_idx += seq_len + spec_idx += total_spec_len + return torch.cat(outputs, dim=0) + + +def create_spec_batch(input_base: torch.Tensor, spec_base: torch.Tensor, seqlens: list[int], speclens: list[tuple[int, int]]): + seq_idx=0 + spec_idx=0 + outputs = [] + for seq_len, (spec_len, spec_branchs) in zip(seqlens, speclens): + for j in range(spec_branchs): + outputs += [input_base[seq_idx:seq_idx+seq_len], spec_base[spec_idx:spec_idx+spec_len]] + spec_idx += spec_len + seq_idx += seq_len + return torch.cat(outputs, dim=0) + + +def tree_seqlens(seqlens, speclens): + return [i+j*b for i, (j, b) in zip(seqlens, speclens)] + + +def batch_seqlens(seqlens, speclens): + seqlens_batch = [] + for i, (j, s) in zip(seqlens, speclens): + seqlens_batch += [i+j]*s + return seqlens_batch + + +def unshuffle_indices(shuffled_indices): + n = len(shuffled_indices) + inverse_indices = [None] * n + + for original_pos, shuffled_pos in enumerate(shuffled_indices): + inverse_indices[shuffled_pos] = original_pos + return inverse_indices + + +def create_block_shuffle(seqlens, paged_kv_block_size, device): + num_blocks = sum(math.ceil(seq / paged_kv_block_size) for seq in seqlens) + block_table = torch.zeros((len(seqlens), math.ceil(max(seqlens)/ paged_kv_block_size)), dtype=torch.int32, device=device) + block_shuffle = torch.randperm(num_blocks, dtype=torch.int32, device=device) + block_unshuffle = torch.tensor(unshuffle_indices(block_shuffle), dtype=torch.int32, device=device) + block_idx = 0 + for i, length in enumerate(seqlens): + blocks_in_seq = math.ceil(length/paged_kv_block_size) + block_table[i, :blocks_in_seq] = block_unshuffle[block_idx:block_idx+blocks_in_seq] + block_idx += blocks_in_seq + return block_table, block_shuffle + + +def to_paged_blocks(sequence: torch.Tensor, seqlens: list[int], block_size: int, nheads: int, d: int, block_table: torch.Tensor): + num_blocks = sum(math.ceil(length / block_size) for length in seqlens) + block_tensor = torch.empty( + num_blocks*block_size, nheads, d, device=sequence.device, dtype=sequence.dtype + ) + + bt_idx = 0 + seq_idx = 0 + for seqlen in seqlens: + block_tensor[bt_idx:bt_idx+seqlen] = sequence[seq_idx:seq_idx+seqlen] + rem = block_size - (seqlen % block_size) if seqlen % block_size != 0 else 0 + bt_idx += seqlen + rem + seq_idx += seqlen + + block_tensor = rearrange(block_tensor, "(num_blocks blocksize) nhead d -> num_blocks blocksize nhead d", blocksize=block_size) + # shuffle blocks based on block table + 
block_tensor = block_tensor.index_select(0, block_table) + return block_tensor + + +def generate_q_and_block_kvcache(q_seqlens: list[int], k_seqlens: list[int], speclens: list[tuple[int, int]], paged_kv_block_size: int, nheads: int, d: int, device, dtype): + # create the input base and individual spec branchs + q_input_base = torch.randn(sum(q_seqlens), nheads, d, device=device, dtype=dtype) + q_spec_base = torch.randn(sum(a * b for a, b in speclens), nheads, d, device=device, dtype=dtype) + + # from the bases create the q for tree attention + q_spec_tree = create_spec_tree(q_input_base, q_spec_base, q_seqlens, speclens) + q_seqlens_tree = tree_seqlens(q_seqlens, speclens) + + # from the bases create the q for varlen attention + q_spec_batch = create_spec_batch(q_input_base, q_spec_base, q_seqlens, speclens) + q_seqlens_batch = batch_seqlens(q_seqlens, speclens) + + del q_input_base + del q_spec_base + + # create the k base and individual spec branches + k_input_base = torch.randn(sum(k_seqlens), nheads, d, device=device, dtype=dtype) + k_spec_base = torch.randn(sum(a * b for a, b in speclens), nheads, d, device=device, dtype=dtype) + + # from the bases create the k for tree attention + k_tree = create_spec_tree(k_input_base, k_spec_base, k_seqlens, speclens) + k_seqlens_tree = tree_seqlens(k_seqlens, speclens) + tree_block_table, tree_block_shuffle = create_block_shuffle(k_seqlens_tree, paged_kv_block_size, device) + k_spec_tree = to_paged_blocks(k_tree, k_seqlens_tree, paged_kv_block_size, nheads, d, tree_block_shuffle) + + # from the bases create the k for varlen attention + k_batch = create_spec_batch(k_input_base, k_spec_base, k_seqlens, speclens) + k_seqlens_batch = batch_seqlens(k_seqlens, speclens) + batch_block_table, batch_block_shuffle = create_block_shuffle(k_seqlens_batch, paged_kv_block_size, device) + k_spec_batch = to_paged_blocks(k_batch, k_seqlens_batch, paged_kv_block_size, nheads, d, batch_block_shuffle) + + del k_input_base + del k_spec_base + del k_tree + del k_batch + + # create the v base and individual spec branches + v_input_base = torch.randn(sum(k_seqlens), nheads, d, device=device, dtype=dtype) + v_spec_base = torch.randn(sum(a * b for a, b in speclens), nheads, d, device=device, dtype=dtype) + + # from the bases create the v for tree attention + v_tree = create_spec_tree(v_input_base, v_spec_base, k_seqlens, speclens) + v_spec_tree = to_paged_blocks(v_tree, k_seqlens_tree, paged_kv_block_size, nheads, d, tree_block_shuffle) + + # from the bases create the v for varlen attention + v_batch = create_spec_batch(v_input_base, v_spec_base, k_seqlens, speclens) + v_spec_batch = to_paged_blocks(v_batch, k_seqlens_batch, paged_kv_block_size, nheads, d, batch_block_shuffle) + + del v_input_base + del v_spec_base + del v_tree + del v_batch + + return q_spec_tree, q_seqlens_tree, q_spec_batch, q_seqlens_batch, tree_block_table, k_spec_tree, v_spec_tree, k_seqlens_tree, batch_block_table, k_spec_batch, v_spec_batch, k_seqlens_batch + + +def treeify_output(t: torch.Tensor, seqlens: int, speclens: int) -> torch.Tensor: + out = torch.empty(sum(i+j*s for i, (j,s) in zip(seqlens, speclens)), t.shape[1], t.shape[2], device=t.device, dtype=t.dtype) + input_idx = 0 + output_idx = 0 + for seq_len, (spec_len, spec_branchs) in zip(seqlens, speclens): + out[output_idx:output_idx+seq_len] = t[input_idx:input_idx+seq_len] + output_idx += seq_len + for _ in range(spec_branchs): + input_idx += seq_len + out[output_idx:output_idx+spec_len] = t[input_idx:input_idx+spec_len] + output_idx += 
spec_len + input_idx += spec_len + return out + + +def deblockify(t: torch.Tensor, block_table: torch.Tensor, seqlens: list[int]): + out = [] + for i, seqlen in enumerate(seqlens): + temp = rearrange(t.index_select(0, block_table[i]), "num_blocks blocksize nhead d -> (num_blocks blocksize) nhead d") + out.append(temp[:seqlen]) + return torch.cat(out, dim=0) \ No newline at end of file
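A worked example of the tree.py helpers above (values traced by hand, not taken from the tests): each speclens entry is a (spec_len, spec_branches) pair, simple_mask produces the parent-index list for spec_branches independent chains of length spec_len, and create_mask turns that parent list into bottom-right-aligned uint64 rows in which every token keeps itself and its ancestors visible.

from vllm_flash_attn.utils.tree import (
    batch_seqlens,
    create_mask,
    simple_mask,
    tree_seqlens,
)

# Three independent speculative chains of length 2 appended to one sequence.
parents = simple_mask(2, 3)
print(parents)                                 # [-1, 0, -1, 2, -1, 4]
print([bin(row) for row in create_mask(parents)])
# ['0b100000', '0b110000', '0b1000', '0b1100', '0b10', '0b11']
# row r attends to column r (itself) plus its chain ancestor, if any.

# Sequence lengths seen by the tree kernel vs. the per-branch varlen reference.
print(tree_seqlens([100], [(2, 3)]))           # [106]: prefix + all 6 speculative tokens once
print(batch_seqlens([100], [(2, 3)]))          # [102, 102, 102]: prefix duplicated per branch

For a batch, tree_mask_lens is then just the cumulative sum of spec_len * spec_branches per sequence (here cumsum([0, 6]) = [0, 6]), which is how the tests and the benchmark construct it.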