1111from vllm import LLM , SamplingParams
1212from vllm .assets .base import VLLM_S3_BUCKET_URL
1313from vllm .assets .image import VLM_IMAGES_DIR
14+ from vllm .config .vllm import VllmConfig
1415from vllm .distributed import cleanup_dist_env_and_memory
16+ from vllm .engine .arg_utils import EngineArgs
1517from vllm .outputs import RequestOutput
1618from vllm .platforms import current_platform
19+ from vllm .v1 .spec_decode .draft_model import create_vllm_config_for_draft_model
1720from vllm .v1 .spec_decode .metrics import compute_acceptance_len , compute_acceptance_rate
1821
1922MTP_SIMILARITY_RATE = 0.8
@@ -359,7 +362,7 @@ def test_mtp_correctness(
359362
360363@dataclass
361364class ArgsTest :
362- model : str
365+ target_model : str
363366 draft_model : str
364367 sampling_config : SamplingParams
365368 num_speculative_tokens : int
@@ -376,7 +379,7 @@ class ArgsTest:
376379cases = [
377380 # Same model for draft and target, greedy sampling.
378381 ArgsTest (
379- model = "Qwen/Qwen3-0.6B" ,
382+ target_model = "Qwen/Qwen3-0.6B" ,
380383 draft_model = "Qwen/Qwen3-0.6B" ,
381384 sampling_config = greedy_sampling (),
382385 num_speculative_tokens = 3 , # K
@@ -386,7 +389,7 @@ class ArgsTest:
386389 ),
387390 # Smaller draft model, stochastic sampling.
388391 ArgsTest (
389- model = "Qwen/Qwen3-1.7B" ,
392+ target_model = "Qwen/Qwen3-1.7B" ,
390393 draft_model = "Qwen/Qwen3-0.6B" ,
391394 sampling_config = stochastic_sampling (),
392395 num_speculative_tokens = 3 ,
@@ -416,31 +419,80 @@ def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
416419def test_draft_model_quantization (models : tuple [str , str ], enforce_eager : bool ):
417420 tgt_model , draft_model = models
418421 sd_case = ArgsTest (
419- model = tgt_model ,
422+ target_model = tgt_model ,
420423 draft_model = draft_model ,
421- sampling_config = greedy_sampling (),
422- num_speculative_tokens = 3 ,
423- expected_acceptance_len = 2.95 + 1 ,
424- expected_acceptance_rate = 0.95 ,
425- expected_same_output_fraction = 0.95 ,
424+ ** some_high_acceptance_metrics (),
426425 )
427426 assert_draft_model_correctness (sd_case , enforce_eager )
428427
429428
def test_draft_model_tensor_parallelism():
    """Spec decode must stay correct when target and draft both run TP=2."""
    case = ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        target_tensor_parallel_size=2,
        draft_model="Qwen/Qwen3-0.6B",
        draft_tensor_parallel_size=2,
        **some_high_acceptance_metrics(),
    )
    assert_draft_model_correctness(case, enforce_eager=False)
439+
440+
def test_draft_model_engine_args_tensor_parallelism():
    """The draft model's VllmConfig must be derived independently of the
    target model's settings (quantization, tensor parallelism, etc.)."""
    args = EngineArgs(
        # Target is FP8-quantized and runs TP=4.
        model="Qwen/Qwen3-1.7B-FP8",
        tensor_parallel_size=4,
        speculative_config={
            # Draft is unquantized and runs TP=1; note the arg name is
            # "draft_tensor_parallel_size", not "tensor_parallel_size".
            "model": "Qwen/Qwen3-0.6B",
            "method": "draft_model",
            "num_speculative_tokens": 3,
            "draft_tensor_parallel_size": 1,
        },
    )
    target_config: VllmConfig = args.create_engine_config()
    assert target_config.parallel_config.tensor_parallel_size == 4
    assert target_config.quant_config.get_name() == "fp8"

    # The derived draft config must not inherit TP size or quantization.
    draft_config: VllmConfig = create_vllm_config_for_draft_model(target_config)
    assert draft_config.parallel_config.tensor_parallel_size == 1
    assert draft_config.quant_config is None
462+
463+
def test_draft_model_engine_args_rejects_invalid_tp_argname():
    """Inside speculative_config the user must write
    "draft_tensor_parallel_size"; the legacy "tensor_parallel_size" name is
    rejected by validation."""
    args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        speculative_config={
            "model": "Qwen/Qwen3-0.6B",
            "method": "draft_model",
            "num_speculative_tokens": 3,
            # Deliberately wrong key name — must raise.
            "tensor_parallel_size": 1,
        },
    )
    with pytest.raises(ValueError):
        args.create_engine_config()
480+
481+
430482def assert_draft_model_correctness (args : ArgsTest , enforce_eager : bool ):
431483 """Compare the outputs using and not using speculative decoding.
432484 In the greedy decoding case, the outputs must match EXACTLY."""
433485 test_prompts = get_test_prompts (mm_enabled = False , quiet = True )
434486
435487 spec_llm = LLM (
436- model = args .model ,
488+ model = args .target_model ,
437489 speculative_config = {
438490 "model" : args .draft_model ,
439491 "method" : "draft_model" ,
440492 "num_speculative_tokens" : args .num_speculative_tokens ,
441493 "max_model_len" : args .max_model_len ,
442494 "enforce_eager" : enforce_eager ,
443- "tensor_parallel_size " : args .draft_tensor_parallel_size ,
495+ "draft_tensor_parallel_size " : args .draft_tensor_parallel_size ,
444496 "disable_padded_drafter_batch" : True ,
445497 "max_num_seqs" : 100 , # limit cudagraph capture runtime
446498 },
@@ -462,7 +514,7 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
462514 assert acceptance_len >= args .expected_acceptance_len
463515
464516 ref_llm = LLM (
465- model = args .model ,
517+ model = args .target_model ,
466518 max_model_len = args .max_model_len ,
467519 gpu_memory_utilization = args .gpu_memory_utilization ,
468520 tensor_parallel_size = args .target_tensor_parallel_size ,
@@ -480,7 +532,7 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
480532 assert match_fraction >= args .expected_same_output_fraction
481533
482534 print (
483- f"spec-decode: target={ args .model } , draft={ args .draft_model } , "
535+ f"spec-decode: target={ args .target_model } , draft={ args .draft_model } , "
484536 f"temperature={ args .sampling_config .temperature :.2f} , "
485537 f"acceptance_rate={ acceptance_rate :.2f} , "
486538 f"acceptance_len={ acceptance_len :.2f} , "
@@ -501,3 +553,13 @@ def compute_exact_matches(
501553 print (f"ref_output: { ref_output .outputs [0 ].text } " )
502554 print (f"spec_output: { spec_output .outputs [0 ].text } " )
503555 return matches / len (ref_outputs )
556+
557+
def some_high_acceptance_metrics() -> dict:
    """Shared ArgsTest kwargs for cases expecting near-perfect acceptance:
    greedy sampling with K=3 speculative tokens."""
    metrics = dict(
        sampling_config=greedy_sampling(),
        num_speculative_tokens=3,
        expected_acceptance_len=2.95 + 1,
        expected_acceptance_rate=0.95,
        expected_same_output_fraction=0.95,
    )
    return metrics
0 commit comments