1+ # Copyright (c) 2025, Samsung SDSA.
2+
3+ import random
4+ import torch
5+
6+ from vllm_flash_attn .utils .benchmark import benchmark_forward
7+
8+ from vllm_flash_attn .flash_attn_interface import (
9+ flash_attn_varlen_func ,
10+ tree_attention ,
11+ )
12+ from vllm_flash_attn .utils .tree import (
13+ create_tree_mask ,
14+ generate_q_and_block_kvcache ,
15+ treeify_output ,
16+ )
17+
18+
def run_tree_attention_benchmark(
    seqlen_q: int = 1024,
    seqlen_k: int = 1024,
    spec_len: tuple[int, int] = (8, 8),
    random_seq_len: bool = False,
    random_spec_len: bool = False,
    batch_size: int = 8,
    nheads: int = 16,
    head_dim: int = 128,
    paged_kv_block_size: int = 256,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
):
    """
    Benchmark tree_attention vs flash_attn_varlen_func performance.

    Similar to test_paged_tree_attention but focused on performance measurement.

    Args:
        seqlen_q: Base query sequence length per batch element.
        seqlen_k: Base key sequence length per batch element.
        spec_len: Per-sequence speculation shape as a 2-tuple.
            # NOTE(review): exact semantics (width/depth?) are defined by
            # create_tree_mask / generate_q_and_block_kvcache — confirm there.
        random_seq_len: If True, jitter each q/k length by randint(0, 20).
        random_spec_len: If True, jitter each spec_len per batch element.
        batch_size: Number of sequences in the batch.
        nheads: Number of attention heads.
        head_dim: Per-head dimension.
        paged_kv_block_size: Block size of the paged KV cache.
        dtype: Tensor dtype for q/k/v.
        device: Device to run on (requires CUDA).

    Returns:
        Dict with 'tree_time', 'varlen_time', 'speedup', 'max_diff' and the
        benchmark 'config' (including observed tensor shapes).
    """
    print("Benchmarking with:")
    print(f"  seqlen_q: {seqlen_q}, seqlen_k: {seqlen_k}")
    print(f"  spec_len: {spec_len}, random_seq_len: {random_seq_len}, random_spec_len: {random_spec_len}")
    print(f"  batch_size: {batch_size}, nheads: {nheads}, head_dim: {head_dim}")
    print(f"  paged_kv_block_size: {paged_kv_block_size}, dtype: {dtype}")

    torch.set_default_device(device)
    # Fixed seeds for reproducibility. Seeding the `random` module was missing
    # before, so runs with random_seq_len/random_spec_len were not reproducible
    # despite the fixed torch seed.
    torch.cuda.manual_seed_all(42)
    random.seed(42)

    # Generate random sequence lengths and spec lengths similar to the test
    if random_seq_len:
        q_seqlens = [seqlen_q + random.randint(0, 20) for _ in range(batch_size)]
        k_seqlens = [seqlen_k + random.randint(0, 20) for _ in range(batch_size)]
    else:
        q_seqlens = [seqlen_q] * batch_size
        k_seqlens = [seqlen_k] * batch_size

    if random_spec_len:
        speclens = [
            (spec_len[0] + random.randint(0, 7), spec_len[1] + random.randint(1, 2))
            for _ in range(batch_size)
        ]
    else:
        speclens = [spec_len] * batch_size

    # Generate test data using the utility function
    (
        q_spec_tree,
        q_seqlens_tree,
        q_spec_batch,
        q_seqlens_batch,
        tree_block_table,
        k_spec_tree,
        v_spec_tree,
        k_seqlens_tree,
        batch_block_table,
        k_spec_batch,
        v_spec_batch,
        k_seqlens_batch,
    ) = generate_q_and_block_kvcache(
        q_seqlens, k_seqlens, speclens, paged_kv_block_size, nheads, head_dim, device, dtype
    )

    # Tree mask plus cumulative sequence bookkeeping (int32, as the kernels take).
    tree_mask = create_tree_mask(speclens, device)
    tree_mask_lens = torch.tensor([0] + [i * j for i, j in speclens], dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)
    cu_seqlens_q_tree = torch.tensor([0] + q_seqlens_tree, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)
    seqused_k_tree = torch.tensor(k_seqlens_tree, dtype=torch.int32)
    cu_seqlens_q_batch = torch.tensor([0] + q_seqlens_batch, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)
    seqused_k_batch = torch.tensor(k_seqlens_batch, dtype=torch.int32)

    print("\nRunning benchmarks...")

    # Benchmark tree_attention
    _, tree_measurement = benchmark_forward(
        tree_attention,
        q_spec_tree,
        k_spec_tree,
        v_spec_tree,
        max(q_seqlens_tree),
        cu_seqlens_q_tree,
        max(k_seqlens_tree),
        tree_mask,
        tree_mask_lens,
        seqused_k=seqused_k_tree,
        block_table=tree_block_table,
        desc="tree_attention",
        verbose=False,
    )
    tree_time = tree_measurement.mean
    print(f"tree_attention average time: {tree_time:.6f} seconds")

    # Benchmark flash_attn_varlen_func (batch-expansion baseline)
    _, varlen_measurement = benchmark_forward(
        flash_attn_varlen_func,
        q_spec_batch,
        k_spec_batch,
        v_spec_batch,
        max(q_seqlens_batch),
        cu_seqlens_q_batch,
        max(k_seqlens_batch),
        seqused_k=seqused_k_batch,
        causal=True,
        block_table=batch_block_table,
        desc="flash_attn_varlen_func",
        verbose=False,
    )
    varlen_time = varlen_measurement.mean
    print(f"flash_attn_varlen_func average time: {varlen_time:.6f} seconds")

    # Calculate speedup. BUGFIX: the old guard checked `varlen_time > 0` but
    # divided by `tree_time`; guard both divisors (tree_time here, speedup in
    # the else branch below).
    if tree_time > 0 and varlen_time > 0:
        speedup = varlen_time / tree_time
        print(f"Speedup (varlen/tree): {speedup:.2f}x")
        if speedup > 1:
            print(f"tree_attention is {speedup:.2f}x faster")
        else:
            print(f"flash_attn_varlen_func is {1 / speedup:.2f}x faster")

    # Verify correctness: tree output must match the treeified batch-expansion
    # output within tolerance.
    print("\nVerifying correctness...")
    tree_output = tree_attention(
        q_spec_tree,
        k_spec_tree,
        v_spec_tree,
        max(q_seqlens_tree),
        cu_seqlens_q_tree,
        max(k_seqlens_tree),
        tree_mask,
        tree_mask_lens,
        seqused_k=seqused_k_tree,
        block_table=tree_block_table,
    )
    varlen_output = flash_attn_varlen_func(
        q_spec_batch,
        k_spec_batch,
        v_spec_batch,
        max(q_seqlens_batch),
        cu_seqlens_q_batch,
        max(k_seqlens_batch),
        seqused_k=seqused_k_batch,
        causal=True,
        block_table=batch_block_table,
    )
    varlen_output_treeified = treeify_output(varlen_output, q_seqlens, speclens)
    try:
        torch.testing.assert_close(tree_output, varlen_output_treeified, atol=2e-2, rtol=1e-2)
    except AssertionError as e:
        print("✗ Outputs differ significantly!")
        print(e)
    else:
        print("✓ Outputs match within tolerance")
    finally:
        # Report the raw max abs difference regardless of pass/fail.
        max_diff = torch.max(torch.abs(tree_output - varlen_output_treeified)).item()
        print(f"Maximum difference between outputs: {max_diff:.6f}")

    return {
        'tree_time': tree_time,
        'varlen_time': varlen_time,
        # BUGFIX: guard the divisor (tree_time), not varlen_time.
        'speedup': varlen_time / tree_time if tree_time > 0 else float('inf'),
        'max_diff': max_diff,
        'config': {
            'seqlen_q': seqlen_q,
            'seqlen_k': seqlen_k,
            'batch_size': batch_size,
            'nheads': nheads,
            'head_dim': head_dim,
            'paged_kv_block_size': paged_kv_block_size,
            'dtype': str(dtype),
            'q_spec_tree.shape': q_spec_tree.shape,
            'k_spec_tree.shape': k_spec_tree.shape,
            'tree_mask.shape': tree_mask.shape,
        }
    }
189+
190+
def run_decoding_benchmark():
    """Run benchmarks for the decoding scenario (seqlen_q=0).

    Sweeps a fixed set of configurations and prints a per-config report plus a
    summary table. Returns the list of result dicts produced by
    run_tree_attention_benchmark.
    """
    configs = [
        # Small sequences with different spec_len and block sizes
        {'seqlen_q': 0, 'seqlen_k': 128, 'batch_size': 4, 'nheads': 8, 'head_dim': 128, 'spec_len': (1, 2), 'paged_kv_block_size': 16},
        {'seqlen_q': 0, 'seqlen_k': 256, 'batch_size': 4, 'nheads': 8, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 16},

        # Medium sequences with varied spec_len and block sizes
        {'seqlen_q': 0, 'seqlen_k': 512, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (1, 2), 'paged_kv_block_size': 256},
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (3, 4), 'paged_kv_block_size': 256},

        # Large sequences with larger block sizes
        {'seqlen_q': 0, 'seqlen_k': 2048, 'batch_size': 4, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 512},

        # Different head dimensions with varied block sizes
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 64, 'spec_len': (1, 2), 'paged_kv_block_size': 256},
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 256, 'spec_len': (2, 3), 'paged_kv_block_size': 512},

        # Different batch sizes with randomization and block sizes
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 2, 'nheads': 16, 'head_dim': 128, 'spec_len': (1, 2), 'random_spec_len': True, 'paged_kv_block_size': 16},
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 16, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'random_seq_len': True, 'paged_kv_block_size': 256},

        # High spec_len scenarios with different block sizes
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (4, 5), 'paged_kv_block_size': 256},
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (6, 8), 'paged_kv_block_size': 512},

        # Block size comparison scenarios
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 16},
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 256},
        {'seqlen_q': 0, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 512},
    ]

    print("=" * 80)
    print("DECODING BENCHMARK (seqlen_q=0)")
    print("=" * 80)
    print("This benchmark represents the decoding scenario where tree attention")
    print("can be compared against batch expansion for generation tasks.")
    print("=" * 80)

    results = []
    for i, config in enumerate(configs):
        print(f"\n[{i + 1}/{len(configs)}] Decoding Configuration:")
        result = run_tree_attention_benchmark(**config)
        results.append(result)
        print("-" * 80)

    # Summary table: one row per config, times in milliseconds.
    print("\n" + "=" * 80)
    print("DECODING BENCHMARK SUMMARY")
    print("=" * 80)
    print(f"{'Config':<18} {'Tree(ms)':<10} {'Varlen(ms)':<12} {'Speedup':<10} {'Max Diff':<12}")
    print("-" * 80)

    # Fixed: the old loop used enumerate() but never used the index.
    for result in results:
        config = result['config']
        config_str = f"{config['seqlen_q']}:{config['seqlen_k']}:{config['tree_mask.shape'][0]}:{config['paged_kv_block_size']}"
        tree_ms = result['tree_time'] * 1000
        varlen_ms = result['varlen_time'] * 1000
        speedup = result['speedup']
        max_diff = result['max_diff']

        print(f"{config_str:<18} {tree_ms:<10.3f} {varlen_ms:<12.3f} {speedup:<10.2f}x {max_diff:<12.6f}")

    return results
255+
256+
def run_comprehensive_benchmark():
    """Run benchmarks across different prefill+speculation configurations.

    Sweeps a fixed set of configurations and prints a per-config report plus a
    summary table. Returns the list of result dicts produced by
    run_tree_attention_benchmark.
    """
    configs = [
        # Small sequences with different spec_len and block sizes
        {'seqlen_q': 128, 'seqlen_k': 128, 'batch_size': 4, 'nheads': 8, 'head_dim': 128, 'spec_len': (1, 2), 'paged_kv_block_size': 16},
        {'seqlen_q': 256, 'seqlen_k': 256, 'batch_size': 4, 'nheads': 8, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 16},

        # Medium sequences with varied spec_len and block sizes
        {'seqlen_q': 512, 'seqlen_k': 512, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (1, 2), 'paged_kv_block_size': 256},
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (3, 4), 'paged_kv_block_size': 256},

        # Large sequences with larger block sizes
        {'seqlen_q': 2048, 'seqlen_k': 2048, 'batch_size': 4, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 512},

        # Different head dimensions with varied block sizes
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 64, 'spec_len': (1, 2), 'paged_kv_block_size': 256},
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 256, 'spec_len': (2, 3), 'paged_kv_block_size': 512},

        # Different batch sizes with randomization and block sizes
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 2, 'nheads': 16, 'head_dim': 128, 'spec_len': (1, 2), 'random_spec_len': True, 'paged_kv_block_size': 16},
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 16, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'random_seq_len': True, 'paged_kv_block_size': 256},

        # High spec_len scenarios with different block sizes
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (4, 5), 'paged_kv_block_size': 256},
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (6, 8), 'paged_kv_block_size': 512},

        # Mixed randomization scenarios with block sizes
        {'seqlen_q': 512, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'random_seq_len': True, 'random_spec_len': True, 'paged_kv_block_size': 256},

        # Block size comparison scenarios
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 16},
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 256},
        {'seqlen_q': 1024, 'seqlen_k': 1024, 'batch_size': 8, 'nheads': 16, 'head_dim': 128, 'spec_len': (2, 3), 'paged_kv_block_size': 512},
    ]

    print("=" * 80)
    print("COMPREHENSIVE TREE ATTENTION BENCHMARK")
    print("=" * 80)

    results = []
    for i, config in enumerate(configs):
        print(f"\n[{i + 1}/{len(configs)}] Configuration:")
        result = run_tree_attention_benchmark(**config)
        results.append(result)
        print("-" * 80)

    # Summary table: one row per config, times in milliseconds.
    print("\n" + "=" * 80)
    print("BENCHMARK SUMMARY")
    print("=" * 80)
    print(f"{'Config':<18} {'Tree(ms)':<10} {'Varlen(ms)':<12} {'Speedup':<10} {'Max Diff':<12}")
    print("-" * 80)

    # Fixed: the old loop used enumerate() but never used the index.
    for result in results:
        config = result['config']
        config_str = f"{config['seqlen_q']}:{config['seqlen_k']}:{config['tree_mask.shape'][0]}:{config['paged_kv_block_size']}"
        tree_ms = result['tree_time'] * 1000
        varlen_ms = result['varlen_time'] * 1000
        speedup = result['speedup']
        max_diff = result['max_diff']

        print(f"{config_str:<18} {tree_ms:<10.3f} {varlen_ms:<12.3f} {speedup:<10.2f}x {max_diff:<12.6f}")

    return results
321+
322+
if __name__ == "__main__":
    # The flash-attention kernels require a CUDA device; bail out early.
    if not torch.cuda.is_available():
        print("CUDA is not available. This benchmark requires GPU.")
        # exit() is a site.py convenience for the REPL and may be absent
        # (python -S, frozen apps); SystemExit is always available.
        raise SystemExit(1)

    print("Tree Attention vs Flash Attention Varlen Benchmark")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Device: {torch.cuda.get_device_name()}")

    # Run single benchmark with default parameters
    print("\n" + "=" * 80)
    print("SINGLE BENCHMARK (1024x1024, batch=8)")
    print("=" * 80)
    run_tree_attention_benchmark()

    # Run decoding benchmark (seqlen_q=0)
    run_decoding_benchmark()

    # Run comprehensive benchmark sweep
    run_comprehensive_benchmark()