
Commit f346cfa

Merge pull request #6 from tomasruizt/feature/correct-tensor-parallelism-on-draft-model
Feature/correct tensor parallelism on draft model
2 parents 37f013e + e1dbab1 commit f346cfa
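
In short, this change routes the draft model's tensor parallelism through "draft_tensor_parallel_size" in the speculative_config and validates it. A minimal usage sketch mirroring the new test cases below (model names and values are taken from those tests and are illustrative only):

from vllm import LLM, SamplingParams

# Sketch based on the updated tests in this commit: the draft model's TP degree
# is passed as "draft_tensor_parallel_size" inside speculative_config, and with
# this change it must match the target model's tensor_parallel_size.
llm = LLM(
    model="Qwen/Qwen3-1.7B",              # target model
    tensor_parallel_size=2,               # target TP
    speculative_config={
        "model": "Qwen/Qwen3-0.6B",       # draft model
        "method": "draft_model",
        "num_speculative_tokens": 3,
        "draft_tensor_parallel_size": 2,  # must equal the target TP
    },
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0, max_tokens=32))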

File tree (5 files changed: +150 additions, -29 deletions)

tests/v1/e2e/test_spec_decode.py
vllm/benchmarks/lib/ready_checker.py
vllm/config/parallel.py
vllm/config/speculative.py
vllm/v1/spec_decode/draft_model.py

tests/v1/e2e/test_spec_decode.py

Lines changed: 75 additions & 13 deletions
@@ -11,9 +11,12 @@
 from vllm import LLM, SamplingParams
 from vllm.assets.base import VLLM_S3_BUCKET_URL
 from vllm.assets.image import VLM_IMAGES_DIR
+from vllm.config.vllm import VllmConfig
 from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.engine.arg_utils import EngineArgs
 from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
+from vllm.v1.spec_decode.draft_model import create_vllm_config_for_draft_model
 from vllm.v1.spec_decode.metrics import compute_acceptance_len, compute_acceptance_rate

 MTP_SIMILARITY_RATE = 0.8
@@ -359,7 +362,7 @@ def test_mtp_correctness(

 @dataclass
 class ArgsTest:
-    model: str
+    target_model: str
     draft_model: str
     sampling_config: SamplingParams
     num_speculative_tokens: int
@@ -376,7 +379,7 @@ class ArgsTest:
 cases = [
     # Same model for draft and target, greedy sampling.
     ArgsTest(
-        model="Qwen/Qwen3-0.6B",
+        target_model="Qwen/Qwen3-0.6B",
         draft_model="Qwen/Qwen3-0.6B",
         sampling_config=greedy_sampling(),
         num_speculative_tokens=3,  # K
@@ -386,7 +389,7 @@ class ArgsTest:
     ),
     # Smaller draft model, stochastic sampling.
     ArgsTest(
-        model="Qwen/Qwen3-1.7B",
+        target_model="Qwen/Qwen3-1.7B",
         draft_model="Qwen/Qwen3-0.6B",
         sampling_config=stochastic_sampling(),
         num_speculative_tokens=3,
@@ -416,31 +419,80 @@ def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
 def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
     tgt_model, draft_model = models
     sd_case = ArgsTest(
-        model=tgt_model,
+        target_model=tgt_model,
         draft_model=draft_model,
-        sampling_config=greedy_sampling(),
-        num_speculative_tokens=3,
-        expected_acceptance_len=2.95 + 1,
-        expected_acceptance_rate=0.95,
-        expected_same_output_fraction=0.95,
+        **some_high_acceptance_metrics(),
     )
     assert_draft_model_correctness(sd_case, enforce_eager)


+def test_draft_model_tensor_parallelism():
+    """Ensure spec decode works when running with TP > 1."""
+    sd_case = ArgsTest(
+        target_model="Qwen/Qwen3-1.7B",
+        target_tensor_parallel_size=2,
+        draft_model="Qwen/Qwen3-0.6B",
+        draft_tensor_parallel_size=2,
+        **some_high_acceptance_metrics(),
+    )
+    assert_draft_model_correctness(sd_case, enforce_eager=False)
+
+
+def test_draft_model_engine_args_tensor_parallelism():
+    """Ensure the vllm_config for the draft model is created correctly,
+    and independently of the target model (quantization, TP, etc.)"""
+
+    engine_args = EngineArgs(
+        model="Qwen/Qwen3-1.7B-FP8",  # <<< tgt quantized
+        tensor_parallel_size=4,
+        speculative_config={
+            "model": "Qwen/Qwen3-0.6B",  # <<< draft not quantized
+            "method": "draft_model",
+            "num_speculative_tokens": 3,
+            "draft_tensor_parallel_size": 1,  # <<< valid arg name
+        },
+    )
+    tgt_vllm_config: VllmConfig = engine_args.create_engine_config()
+    assert tgt_vllm_config.parallel_config.tensor_parallel_size == 4
+    assert tgt_vllm_config.quant_config.get_name() == "fp8"
+
+    draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config)
+    assert draft_vllm_config.parallel_config.tensor_parallel_size == 1
+    assert draft_vllm_config.quant_config is None
+
+
+def test_draft_model_engine_args_rejects_invalid_tp_argname():
+    """The user should pass "draft_tensor_parallel_size" rather than
+    "tensor_parallel_size". We enforce this with validation."""
+
+    engine_args = EngineArgs(
+        model="Qwen/Qwen3-1.7B",
+        tensor_parallel_size=1,
+        speculative_config={
+            "model": "Qwen/Qwen3-0.6B",
+            "method": "draft_model",
+            "num_speculative_tokens": 3,
+            "tensor_parallel_size": 1,  # <<< invalid arg name
+        },
+    )
+    with pytest.raises(ValueError):
+        engine_args.create_engine_config()
+
+
 def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
     """Compare the outputs using and not using speculative decoding.
     In the greedy decoding case, the outputs must match EXACTLY."""
     test_prompts = get_test_prompts(mm_enabled=False, quiet=True)

     spec_llm = LLM(
-        model=args.model,
+        model=args.target_model,
         speculative_config={
             "model": args.draft_model,
             "method": "draft_model",
             "num_speculative_tokens": args.num_speculative_tokens,
             "max_model_len": args.max_model_len,
             "enforce_eager": enforce_eager,
-            "tensor_parallel_size": args.draft_tensor_parallel_size,
+            "draft_tensor_parallel_size": args.draft_tensor_parallel_size,
             "disable_padded_drafter_batch": True,
             "max_num_seqs": 100,  # limit cudagraph capture runtime
         },
@@ -462,7 +514,7 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
     assert acceptance_len >= args.expected_acceptance_len

     ref_llm = LLM(
-        model=args.model,
+        model=args.target_model,
         max_model_len=args.max_model_len,
         gpu_memory_utilization=args.gpu_memory_utilization,
         tensor_parallel_size=args.target_tensor_parallel_size,
@@ -480,7 +532,7 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
     assert match_fraction >= args.expected_same_output_fraction

     print(
-        f"spec-decode: target={args.model}, draft={args.draft_model}, "
+        f"spec-decode: target={args.target_model}, draft={args.draft_model}, "
         f"temperature={args.sampling_config.temperature:.2f}, "
         f"acceptance_rate={acceptance_rate:.2f}, "
         f"acceptance_len={acceptance_len:.2f}, "
@@ -501,3 +553,13 @@ def compute_exact_matches(
         print(f"ref_output: {ref_output.outputs[0].text}")
         print(f"spec_output: {spec_output.outputs[0].text}")
     return matches / len(ref_outputs)
+
+
+def some_high_acceptance_metrics() -> dict:
+    return {
+        "sampling_config": greedy_sampling(),
+        "num_speculative_tokens": 3,
+        "expected_acceptance_len": 2.95 + 1,
+        "expected_acceptance_rate": 0.95,
+        "expected_same_output_fraction": 0.95,
+    }

vllm/benchmarks/lib/ready_checker.py

Lines changed: 6 additions & 0 deletions
@@ -8,8 +8,12 @@
 import aiohttp
 from tqdm.asyncio import tqdm

+from vllm.logger import init_logger
+
 from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput

+logger = init_logger(__name__)
+

 async def wait_for_endpoint(
     request_func: RequestFunc,
@@ -61,6 +65,8 @@ async def wait_for_endpoint(
             if output.success:
                 pbar.close()
                 return output
+            else:
+                logger.warning("Endpoint is not ready. Error='%s'", output.error)
         except aiohttp.ClientConnectorError:
             pass

vllm/config/parallel.py

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,7 @@

 import hashlib
 import os
+from dataclasses import replace
 from typing import TYPE_CHECKING, Any, Literal

 import torch
@@ -564,3 +565,6 @@ def _verify_args(self) -> Self:
         )

         return self
+
+    def replace(self, **kwargs) -> Self:
+        return replace(self, **kwargs)
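
The new ParallelConfig.replace() is a thin wrapper around dataclasses.replace(), so callers (such as create_vllm_config_for_draft_model further down in this commit) get a copy with selected fields overridden instead of mutating the shared config. A minimal sketch of the pattern, using a hypothetical stand-in dataclass rather than the real ParallelConfig:

from dataclasses import dataclass, replace

# ToyParallelConfig is an illustrative stand-in, not the vllm class; the point
# is that .replace() returns a new object and leaves the original untouched.
@dataclass
class ToyParallelConfig:
    tensor_parallel_size: int = 1
    rank: int = 0

    def replace(self, **kwargs) -> "ToyParallelConfig":
        return replace(self, **kwargs)

cfg = ToyParallelConfig(tensor_parallel_size=2, rank=0)
cfg_for_rank_1 = cfg.replace(rank=1)  # copy with one field overridden
assert cfg.rank == 0 and cfg_for_rank_1.rank == 1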

vllm/config/speculative.py

Lines changed: 10 additions & 0 deletions
@@ -79,6 +79,10 @@ class SpeculativeConfig:
     draft_tensor_parallel_size: int | None = None
     """The degree of the tensor parallelism for the draft model. Can only be 1
     or the same as the target model's tensor parallel size."""
+    tensor_parallel_size: int | None = None
+    """Users should pass "draft_tensor_parallel_size" instead. This parameter
+    exists only so it can be rejected if passed."""
+
     disable_logprobs: bool = True
     """If set to True, token log probabilities are not returned during
     speculative decoding. If set to False, token log probabilities are returned
@@ -537,6 +541,12 @@ def create_draft_parallel_config(

     @model_validator(mode="after")
     def _verify_args(self) -> Self:
+        if self.tensor_parallel_size is not None:
+            raise ValueError(
+                "'tensor_parallel_size' is not a valid argument in the "
+                "speculative_config. Please pass 'draft_tensor_parallel_size' instead."
+            )
+
         if self.num_speculative_tokens is None:
             raise ValueError(
                 "num_speculative_tokens must be provided with "

vllm/v1/spec_decode/draft_model.py

Lines changed: 55 additions & 16 deletions
@@ -6,7 +6,9 @@
 import torch

 from vllm.attention.layer import Attention
-from vllm.config import ModelConfig, VllmConfig, get_layers_from_vllm_config
+from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config.speculative import SpeculativeConfig
+from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
@@ -16,6 +18,8 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, SpecDecodeBaseProposer

+logger = init_logger(__name__)
+

 class DraftModelProposer(SpecDecodeBaseProposer):
     def __init__(
@@ -34,6 +38,7 @@ def __init__(
         self._raise_if_mrope()
         self._raise_if_padded_drafter_batch()
         self._raise_if_vocab_size_mismatch()
+        self._raise_if_draft_tp_mismatch()

     def propose(
         self,
@@ -98,12 +103,29 @@ def _raise_if_padded_drafter_batch(self):
             raise NotImplementedError(
                 "Speculative Decoding with draft models does not support "
                 "padded drafter batch yet. Please pass --disable-padded-drafter-batch "
-                "in the speculative config."
+                "in the speculative_config."
             )

     def _raise_if_vocab_size_mismatch(self):
         self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model()

+    def _raise_if_draft_tp_mismatch(self):
+        # Note(Tomas Ruiz): If we run the target model with TP > 1 and
+        # the draft model with TP = 1, then the different TP ranks collide.
+        # Specifically, when all ranks compile the draft model on rank 0
+        # (because TP=1), the torch compile cache is overwritten and corrupted.
+        # We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
+        # To prevent this error, we assert that both TP sizes must be the same.
+        spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config
+        tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
+        draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
+        if draft_tp != tgt_tp:
+            raise ValueError(
+                f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
+                f"must be the same. Got {draft_tp} and {tgt_tp}. "
+                "Please pass 'draft_tensor_parallel_size' in the speculative_config."
+            )
+
     def set_input_ids_first_pass(
         self,
         target_token_ids: torch.Tensor,
@@ -115,15 +137,6 @@ def set_input_ids_first_pass(

     def load_model(self, target_model: Any) -> None:
         """Takes target_model to satisfy the type checker."""
-        draft_model_config: ModelConfig = (
-            self.vllm_config.speculative_config.draft_model_config
-        )
-        # Recompute quant_config, which is configured for the target model,
-        # but the draft model might not be quantized.
-        vllm_config_draft: VllmConfig = self.vllm_config.replace(
-            quant_config=None,
-            model_config=draft_model_config,
-        )

         # This must be computed before loading the draft model
         # because that mutates the forward_context of the vllm_config
@@ -133,12 +146,17 @@ def load_model(self, target_model: Any) -> None:

         from vllm.compilation.backends import set_model_tag

+        draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(
+            target_model_vllm_config=self.vllm_config
+        )
+        logger.info(
+            "Starting to load draft model %s. TP=%d, rank=%d",
+            draft_vllm_config.model_config.model,
+            draft_vllm_config.parallel_config.tensor_parallel_size,
+            draft_vllm_config.parallel_config.rank,
+        )
         with set_model_tag("draft_model"):
-            self.model = get_model(
-                vllm_config=vllm_config_draft,
-                model_config=draft_model_config,
-                prefix="draft_model",
-            )
+            self.model = get_model(vllm_config=draft_vllm_config, prefix="draft_model")

         # This must be computed after loading the draft model
         # because that mutates the forward_context of the vllm_config
@@ -149,6 +167,27 @@ def load_model(self, target_model: Any) -> None:
         self.attn_layer_names = list(draft_attn_layer_names)


+def create_vllm_config_for_draft_model(
+    target_model_vllm_config: VllmConfig,
+) -> VllmConfig:
+    """The vllm_config is configured for the target model, e.g. its quant_config
+    and parallel_config. But the draft model is potentially quantized differently
+    and may have a different tensor_parallel_size. This function creates a new
+    vllm_config configured for the draft model, which is useful when loading the
+    draft model with get_model().
+    """
+    old = target_model_vllm_config
+    new_parallel_config = old.speculative_config.draft_parallel_config.replace(
+        rank=old.parallel_config.rank
+    )
+    new: VllmConfig = old.replace(
+        quant_config=None,  # quant_config is recomputed in __init__()
+        model_config=old.speculative_config.draft_model_config,
+        parallel_config=new_parallel_config,
+    )
+    return new
+
+
 @dataclass
 class DraftModelInputs:
     token_ids: torch.Tensor
