Commit feaab45

feat: implement tree attention mask support for FlashAttention-2
Add comprehensive tree attention implementation including:

- Core tree attention CUDA kernels for multiple head dimensions (32-256) and precisions (FP16/BF16)
- Tree mask utilities for structured attention patterns in speculative decoding
- Python interface and C++ bindings for the tree_attention function
- Benchmarking suite comparing tree attention vs varlen flash attention
- Test utilities for paged KV cache with tree attention patterns

This enables efficient speculative decoding by avoiding batch expansion, providing memory-efficient attention computation for tree-structured token generation patterns commonly used in speculative sampling.

vllm-project/vllm#18327

Signed-off-by: Andrew O'Neill <a.oneill@samsung.com>
1 parent 57b4e68 commit feaab45
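
For orientation: in tree-based speculative decoding, each draft token may attend to the committed prefix plus its own ancestors in the draft tree, and the tree mask encodes that ancestor relation so all draft branches can share one sequence and KV cache instead of being expanded into separate batch entries. The snippet below is a minimal, library-independent sketch of that ancestor rule in plain PyTorch; it does not reproduce the exact layout produced by create_tree_mask in this commit, and the parents example is made up.

# Minimal sketch of a tree attention mask (assumption: illustrative only, not
# the layout built by create_tree_mask). A draft tree is given as one parent
# index per node (-1 for the root); node i may attend to node j iff j == i or
# j is an ancestor of i.
import torch

def ancestor_mask(parents: list[int]) -> torch.Tensor:
    n = len(parents)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        j = i
        while j != -1:  # walk up to the root, marking every ancestor
            mask[i, j] = True
            j = parents[j]
    return mask

# Example tree: root 0 with children 1 and 2; node 3 extends the branch at 1.
print(ancestor_mask([-1, 0, 0, 1]).long())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 0, 1, 0],
#         [1, 1, 0, 1]])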

File tree

45 files changed: +3734 −3 lines changed


AUTHORS

Lines changed: 2 additions & 1 deletion

@@ -1 +1,2 @@
-Tri Dao, trid@cs.stanford.edu
+Tri Dao, trid@cs.stanford.edu
+Andrew O'Neill (Samsung SDSA), a.oneill@samsung.com

CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -150,6 +150,7 @@ if (FA2_ENABLED)
     SOURCES
       csrc/flash_attn/flash_api.cpp
       csrc/flash_attn/flash_api_sparse.cpp
+      csrc/flash_attn/tree_attention.cpp
       csrc/flash_attn/flash_api_torch_lib.cpp
       ${FA2_GEN_SRCS}
     COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
New file: tree attention benchmark script (filename not shown in this view)

Lines changed: 343 additions & 0 deletions

@@ -0,0 +1,343 @@
# Copyright (c) 2025, Samsung SDSA.

import random
import torch

from vllm_flash_attn.utils.benchmark import benchmark_forward

from vllm_flash_attn.flash_attn_interface import (
    flash_attn_varlen_func,
    tree_attention,
)
from vllm_flash_attn.utils.tree import (
    create_tree_mask,
    generate_q_and_block_kvcache,
    treeify_output,
)


def run_tree_attention_benchmark(
    seqlen_q: int = 1024,
    seqlen_k: int = 1024,
    spec_len: tuple[int] = (8, 8),
    random_seq_len: bool = False,
    random_spec_len: bool = False,
    batch_size: int = 8,
    nheads: int = 16,
    head_dim: int = 128,
    paged_kv_block_size: int = 256,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
):
    """
    Benchmark tree_attention vs flash_attn_varlen_func performance.

    Similar to test_paged_tree_attention but focused on performance measurement.
    """
    print("Benchmarking with:")
    print(f" seqlen_q: {seqlen_q}, seqlen_k: {seqlen_k}")
    print(f" spec_len: {spec_len}, random_seq_len: {random_seq_len}, random_spec_len: {random_spec_len}")
    print(f" batch_size: {batch_size}, nheads: {nheads}, head_dim: {head_dim}")
    print(f" paged_kv_block_size: {paged_kv_block_size}, dtype: {dtype}")

    torch.set_default_device(device)
    torch.cuda.manual_seed_all(42)  # Fixed seed for reproducibility

    # Generate random sequence lengths and spec lengths similar to the test
    if random_seq_len:
        q_seqlens = [seqlen_q + random.randint(0, 20) for _ in range(batch_size)]
        k_seqlens = [seqlen_k + random.randint(0, 20) for _ in range(batch_size)]
    else:
        q_seqlens = [seqlen_q] * batch_size
        k_seqlens = [seqlen_k] * batch_size

    if random_spec_len:
        speclens = [(spec_len[0] + random.randint(0, 7), spec_len[1] + random.randint(1, 2)) for _ in range(batch_size)]
    else:
        speclens = [spec_len] * batch_size

    # Generate test data using the utility function
    (
        q_spec_tree,
        q_seqlens_tree,
        q_spec_batch,
        q_seqlens_batch,
        tree_block_table,
        k_spec_tree,
        v_spec_tree,
        k_seqlens_tree,
        batch_block_table,
        k_spec_batch,
        v_spec_batch,
        k_seqlens_batch,
    ) = generate_q_and_block_kvcache(
        q_seqlens, k_seqlens, speclens, paged_kv_block_size, nheads, head_dim, device, dtype
    )

    # Create tree mask and cumulative sequence lengths
    tree_mask = create_tree_mask(speclens, device)
    tree_mask_lens = torch.tensor([0] + [i * j for i, j in speclens], dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)
    cu_seqlens_q_tree = torch.tensor([0] + q_seqlens_tree, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)
    seqused_k_tree = torch.tensor(k_seqlens_tree, dtype=torch.int32)
    cu_seqlens_q_batch = torch.tensor([0] + q_seqlens_batch, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)
    seqused_k_batch = torch.tensor(k_seqlens_batch, dtype=torch.int32)

    print("\nRunning benchmarks...")

    # Benchmark tree_attention
    _, tree_measurement = benchmark_forward(
        tree_attention,
        q_spec_tree,
        k_spec_tree,
        v_spec_tree,
        max(q_seqlens_tree),
        cu_seqlens_q_tree,
        max(k_seqlens_tree),
        tree_mask,
        tree_mask_lens,
        seqused_k=seqused_k_tree,
        block_table=tree_block_table,
        desc="tree_attention",
        verbose=False
    )
    tree_time = tree_measurement.mean
    print(f"tree_attention average time: {tree_time:.6f} seconds")

    # Benchmark flash_attn_varlen_func
    _, varlen_measurement = benchmark_forward(
        flash_attn_varlen_func,
        q_spec_batch,
        k_spec_batch,
        v_spec_batch,
        max(q_seqlens_batch),
        cu_seqlens_q_batch,
        max(k_seqlens_batch),
        seqused_k=seqused_k_batch,
        causal=True,
        block_table=batch_block_table,
        desc="flash_attn_varlen_func",
        verbose=False
    )
    varlen_time = varlen_measurement.mean
    print(f"flash_attn_varlen_func average time: {varlen_time:.6f} seconds")

    # Calculate speedup
    if varlen_time > 0:
        speedup = varlen_time / tree_time
        print(f"Speedup (varlen/tree): {speedup:.2f}x")
        if speedup > 1:
            print(f"tree_attention is {speedup:.2f}x faster")
        else:
            print(f"flash_attn_varlen_func is {1/speedup:.2f}x faster")

    # Verify correctness
    print("\nVerifying correctness...")
    tree_output = tree_attention(
        q_spec_tree,
        k_spec_tree,
        v_spec_tree,
        max(q_seqlens_tree),
        cu_seqlens_q_tree,
        max(k_seqlens_tree),
        tree_mask,
        tree_mask_lens,
        seqused_k=seqused_k_tree,
        block_table=tree_block_table,
    )
    varlen_output = flash_attn_varlen_func(
        q_spec_batch,
        k_spec_batch,
        v_spec_batch,
        max(q_seqlens_batch),
        cu_seqlens_q_batch,
        max(k_seqlens_batch),
        seqused_k=seqused_k_batch,
        causal=True,
        block_table=batch_block_table,
    )
    varlen_output_treeified = treeify_output(varlen_output, q_seqlens, speclens)
    try:
        torch.testing.assert_close(tree_output, varlen_output_treeified, atol=2e-2, rtol=1e-2)
    except AssertionError as e:
        print("✗ Outputs differ significantly!")
        print(e)
    else:
        print("✓ Outputs match within tolerance")
    finally:
        max_diff = torch.max(torch.abs(tree_output - varlen_output_treeified)).item()
        print(f"Maximum difference between outputs: {max_diff:.6f}")

    return {
        'tree_time': tree_time,
        'varlen_time': varlen_time,
        'speedup': varlen_time / tree_time if varlen_time > 0 else float('inf'),
        'max_diff': max_diff,
        'config': {
            'seqlen_q': seqlen_q,
            'seqlen_k': seqlen_k,
            'batch_size': batch_size,
            'nheads': nheads,
            'head_dim': head_dim,
            'paged_kv_block_size': paged_kv_block_size,
            'dtype': str(dtype),
            'q_spec_tree.shape': q_spec_tree.shape,
            'k_spec_tree.shape': k_spec_tree.shape,
            'tree_mask.shape': tree_mask.shape,
        }
    }


def run_decoding_benchmark():
    """Run benchmarks for decoding scenario with seqlen_q=0."""
    configs = [
        # Small sequences with different spec_len and block sizes
        {'seqlen_q': 0, 'seqlen_k': 128, 'batch_size': 4, 'nheads': 8, 'head_dim': 128, 'spec_len': (1, 2), 'paged_kv_block_size': 16},
        {'seqlen_q': 0, 'seqlen_k': 256, 'batch_size': 4, 'nheads': 8, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 16},

        # Medium sequences with varied spec_len and block sizes
        {'seqlen_q': 0, 'seqlen_k': 512, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (1, 2), 'paged_kv_block_size': 256},
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (3, 4), 'paged_kv_block_size': 256},

        # Large sequences with larger block sizes
        {'seqlen_q': 0, 'seqlen_k': 2048, 'batch_size': 4, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 512},

        # Different head dimensions with varied block sizes
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 64, 'spec_len': (1, 2), 'paged_kv_block_size': 256},
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 256, 'spec_len': (2, 3), 'paged_kv_block_size': 512},

        # Different batch sizes with randomization and block sizes
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 2, 'nheads': 16, 'head_dim': 128, 'spec_len': (1, 2), 'random_spec_len': True, 'paged_kv_block_size': 16},
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 16, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'random_seq_len': True, 'paged_kv_block_size': 256},

        # High spec_len scenarios with different block sizes
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (4, 5), 'paged_kv_block_size': 256},
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (6, 8), 'paged_kv_block_size': 512},

        # Block size comparison scenarios
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 16},
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 256},
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 512},
    ]

    print("=" * 80)
    print("DECODING BENCHMARK (seqlen_q=0)")
    print("=" * 80)
    print("This benchmark represents the decoding scenario where tree attention")
    print("can be compared against batch expansion for generation tasks.")
    print("=" * 80)

    results = []
    for i, config in enumerate(configs):
        print(f"\n[{i+1}/{len(configs)}] Decoding Configuration:")
        result = run_tree_attention_benchmark(**config)
        results.append(result)
        print("-" * 80)

    # Summary
    print("\n" + "=" * 80)
    print("DECODING BENCHMARK SUMMARY")
    print("=" * 80)
    print(f"{'Config':<18} {'Tree(ms)':<10} {'Varlen(ms)':<12} {'Speedup':<10} {'Max Diff':<12}")
    print("-" * 80)

    for i, result in enumerate(results):
        config = result['config']
        config_str = f"{config['seqlen_q']}:{config['seqlen_k']}:{config['tree_mask.shape'][0]}:{config['paged_kv_block_size']}"
        tree_ms = result['tree_time'] * 1000
        varlen_ms = result['varlen_time'] * 1000
        speedup = result['speedup']
        max_diff = result['max_diff']

        print(f"{config_str:<18} {tree_ms:<10.3f} {varlen_ms:<12.3f} {speedup:<10.2f}x {max_diff:<12.6f}")

    return results


def run_comprehensive_benchmark():
    """Run benchmarks across different configurations."""
    configs = [
        # Small sequences with different spec_len and block sizes
        {'seqlen_q': 128, 'seqlen_k': 128, 'batch_size': 4, 'nheads': 8, 'head_dim': 128, 'spec_len': (1, 2), 'paged_kv_block_size': 16},
        {'seqlen_q': 256, 'seqlen_k': 256, 'batch_size': 4, 'nheads': 8, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 16},

        # Medium sequences with varied spec_len and block sizes
        {'seqlen_q': 512, 'seqlen_k': 512, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (1, 2), 'paged_kv_block_size': 256},
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (3, 4), 'paged_kv_block_size': 256},

        # Large sequences with larger block sizes
        {'seqlen_q': 2048, 'seqlen_k': 2048, 'batch_size': 4, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 512},

        # Different head dimensions with varied block sizes
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 64, 'spec_len': (1, 2), 'paged_kv_block_size': 256},
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 256, 'spec_len': (2, 3), 'paged_kv_block_size': 512},

        # Different batch sizes with randomization and block sizes
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 2, 'nheads': 16, 'head_dim': 128, 'spec_len': (1, 2), 'random_spec_len': True, 'paged_kv_block_size': 16},
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 16, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'random_seq_len': True, 'paged_kv_block_size': 256},

        # High spec_len scenarios with different block sizes
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (4, 5), 'paged_kv_block_size': 256},
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (6, 8), 'paged_kv_block_size': 512},

        # Mixed randomization scenarios with block sizes
        {'seqlen_q': 512, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'random_seq_len': True, 'random_spec_len': True, 'paged_kv_block_size': 256},

        # Block size comparison scenarios
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 16},
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 256},
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 512},
    ]

    print("=" * 80)
    print("COMPREHENSIVE TREE ATTENTION BENCHMARK")
    print("=" * 80)

    results = []
    for i, config in enumerate(configs):
        print(f"\n[{i+1}/{len(configs)}] Configuration:")
        result = run_tree_attention_benchmark(**config)
        results.append(result)
        print("-" * 80)

    # Summary
    print("\n" + "=" * 80)
    print("BENCHMARK SUMMARY")
    print("=" * 80)
    print(f"{'Config':<18} {'Tree(ms)':<10} {'Varlen(ms)':<12} {'Speedup':<10} {'Max Diff':<12}")
    print("-" * 80)

    for i, result in enumerate(results):
        config = result['config']
        config_str = f"{config['seqlen_q']}:{config['seqlen_k']}:{config['tree_mask.shape'][0]}:{config['paged_kv_block_size']}"
        tree_ms = result['tree_time'] * 1000
        varlen_ms = result['varlen_time'] * 1000
        speedup = result['speedup']
        max_diff = result['max_diff']

        print(f"{config_str:<18} {tree_ms:<10.3f} {varlen_ms:<12.3f} {speedup:<10.2f}x {max_diff:<12.6f}")

    return results


if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("CUDA is not available. This benchmark requires GPU.")
        exit(1)

    print("Tree Attention vs Flash Attention Varlen Benchmark")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Device: {torch.cuda.get_device_name()}")

    # Run single benchmark
    print("\n" + "=" * 80)
    print("SINGLE BENCHMARK (1024x1024, batch=8)")
    print("=" * 80)
    run_tree_attention_benchmark()

    # Run decoding benchmark
    run_decoding_benchmark()

    # Run comprehensive benchmark
    run_comprehensive_benchmark()
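
For reference, a single configuration from the suites above can also be invoked directly. A minimal sketch, assuming run_tree_attention_benchmark and torch are already in scope (the benchmark module's filename is not shown in this commit view):

# Hypothetical direct call of one decoding-style configuration; the returned
# dict is the one assembled at the end of run_tree_attention_benchmark above.
if torch.cuda.is_available():
    result = run_tree_attention_benchmark(
        seqlen_q=0,
        seqlen_k=1024,
        batch_size=8,
        nheads=16,
        head_dim=128,
        spec_len=(2, 3),
        paged_kv_block_size=256,
    )
    print(f"speedup (varlen/tree): {result['speedup']:.2f}x, max diff: {result['max_diff']:.6f}")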
