@@ -159,10 +159,8 @@ def test_eagle_correctness(
     attn_backend: str,
 ):
     if attn_backend == "TREE_ATTN":
-        # TODO: Fix this flaky test
         pytest.skip(
-            "TREE_ATTN is flaky in the test disable for now until it can be "
-            "resolved (see https://github.com/vllm-project/vllm/issues/22922)")
+            "TREE_ATTN is tested separately in test_tree_eagle_correctness.")

     # Generate test prompts inside the function instead of using fixture
     test_prompts = get_test_prompts(mm_enabled)
@@ -223,3 +221,83 @@ def test_eagle_correctness(
         del spec_llm
         torch.cuda.empty_cache()
         cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize("model_setup", [
+    ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+     "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
+    ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+     "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
+],
+                         ids=[
+                             "llama3_eagle",
+                             "llama3_eagle3",
+                         ])
+@pytest.mark.parametrize(
+    "spec_token_tree",
+    [
+        [(0, )],  # A single token
+        [(0, ), (0, 0), (0, 0, 0)],  # Chain
+        [(0, ), (1, ), (2, )],  # Parallel
+        [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
+         (2, 1)],  # Tree
+    ])
+def test_tree_eagle_correctness(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_setup: tuple[str, str, str, int],
+    spec_token_tree: list[tuple[int, ...]],
+):
+    """
+    Compare the outputs of the original LLM and the speculative LLM;
+    they should be the same when using EAGLE speculative decoding.
+    model_setup: (method, model_name, eagle_model_name, tp_size)
+    """
+    # Generate test prompts inside the function instead of using fixture
+    test_prompts = get_test_prompts(False)
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_ATTENTION_BACKEND", "TREE_ATTN")
+        method, model_name, spec_model_name, tp_size = model_setup
+
+        ref_llm = LLM(model=model_name,
+                      max_model_len=2048,
+                      tensor_parallel_size=tp_size)
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+        del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        spec_llm = LLM(
+            model=model_name,
+            trust_remote_code=True,
+            tensor_parallel_size=tp_size,
+            speculative_config={
+                "method": method,
+                "model": spec_model_name,
+                "num_speculative_tokens": len(spec_token_tree),
+                "spec_token_tree": str(spec_token_tree),
+                "max_model_len": 2048,
+            },
+            max_model_len=2048,
+        )
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
+        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 50% of the prompts to match exactly.
+        # Upon failure, inspect the outputs to check for inaccuracy. This
+        # threshold is lower than in the other tests because the tree
+        # attention backend uses Triton kernels, which seem to introduce
+        # more floating-point non-determinism than FA3.
+        assert matches > int(0.50 * len(ref_outputs))
+        del spec_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
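
For readers skimming the parametrization above: each spec_token_tree entry is the path of child indices from an implicit root to one draft node, so num_speculative_tokens equals the number of tuples. Below is a minimal reading aid, assuming only the tuple-path convention visible in the test; describe_tree is a hypothetical helper for illustration, not a vLLM API.

# Illustrative sketch (not part of the diff above) of how the spec_token_tree
# tuples can be interpreted as a draft token tree.
def describe_tree(spec_token_tree):
    for node in spec_token_tree:
        # Length-1 tuples hang directly off the root; otherwise the parent is
        # the same path with the last child index dropped.
        parent = "root" if len(node) == 1 else str(node[:-1])
        print(f"node {node}: depth={len(node)}, parent={parent}")

# The "Tree" case from the parametrization: a root with three children, each
# of which has two children of its own (9 draft tokens in total).
describe_tree([(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
               (2, 1)])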