Skip to content

Commit cb1447f

Browse files
[Bug fix] vLLM upstream compatibility. Fix DP scheduler (#1057)
Signed-off-by: wenxindongwork <wenxindong@google.com>
1 parent 6001414 commit cb1447f

File tree

2 files changed

+62
-61
lines changed

2 files changed

+62
-61
lines changed

tests/core/test_dp_scheduler.py

Lines changed: 53 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from vllm.config import VllmConfig
66
from vllm.v1.core.sched.output import (CachedRequestData, GrammarOutput,
77
SchedulerOutput)
8+
from vllm.v1.core.sched.scheduler import Scheduler
89
from vllm.v1.engine import EngineCoreOutputs
910
from vllm.v1.kv_cache_interface import KVCacheConfig
1011
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
@@ -24,9 +25,10 @@ def mock_vllm_config(self):
2425
config.sharding_config = MagicMock()
2526
config.sharding_config.total_dp_size = 2
2627
config.scheduler_config = MagicMock()
27-
config.scheduler_config._original_scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
28+
config.scheduler_config._original_scheduler_cls = Scheduler
2829
config.scheduler_config.max_num_seqs = 8
2930
config.scheduler_config.max_num_batched_tokens = 1024
31+
config.scheduler_config.async_scheduling = False
3032
return config
3133

3234
@pytest.fixture
@@ -46,18 +48,14 @@ def _create_dp_scheduler_with_mocks(self, mock_vllm_config,
4648
mock_structured_output_manager,
4749
**kwargs):
4850
"""Helper to create a DPScheduler with properly mocked schedulers."""
49-
with patch(
50-
"tpu_inference.core.sched.dp_scheduler.resolve_obj_by_qualname"
51-
) as mock_resolve:
52-
# Create individual mock scheduler instances
53-
mock_scheduler_0 = MagicMock()
54-
mock_scheduler_1 = MagicMock()
55-
56-
# Set up the mock class to return these instances
57-
mock_scheduler_cls = MagicMock(
58-
side_effect=[mock_scheduler_0, mock_scheduler_1])
59-
mock_resolve.return_value = mock_scheduler_cls
60-
51+
# Create individual mock scheduler instances
52+
mock_scheduler_0 = MagicMock()
53+
mock_scheduler_1 = MagicMock()
54+
55+
# Patch the Scheduler class to return our mock instances
56+
with patch.object(
57+
mock_vllm_config.scheduler_config, '_original_scheduler_cls',
58+
MagicMock(side_effect=[mock_scheduler_0, mock_scheduler_1])):
6159
scheduler = DPScheduler(
6260
vllm_config=mock_vllm_config,
6361
kv_cache_config=mock_kv_cache_config,
@@ -67,38 +65,36 @@ def _create_dp_scheduler_with_mocks(self, mock_vllm_config,
6765

6866
return scheduler
6967

70-
@patch("tpu_inference.core.sched.dp_scheduler.resolve_obj_by_qualname")
7168
def test_init_creates_per_rank_schedulers(
7269
self,
73-
mock_resolve,
7470
mock_vllm_config,
7571
mock_kv_cache_config,
7672
mock_structured_output_manager,
7773
):
7874
"""Test Initialization creates schedulers for each DP rank."""
7975
# Mock the scheduler class
80-
mock_scheduler_cls = MagicMock()
8176
mock_scheduler_instance = MagicMock()
82-
mock_scheduler_cls.return_value = mock_scheduler_instance
83-
mock_resolve.return_value = mock_scheduler_cls
84-
85-
scheduler = DPScheduler(
86-
vllm_config=mock_vllm_config,
87-
kv_cache_config=mock_kv_cache_config,
88-
structured_output_manager=mock_structured_output_manager,
89-
block_size=16,
90-
log_stats=True,
91-
)
77+
mock_scheduler_cls = MagicMock(return_value=mock_scheduler_instance)
78+
79+
with patch.object(mock_vllm_config.scheduler_config,
80+
'_original_scheduler_cls', mock_scheduler_cls):
81+
scheduler = DPScheduler(
82+
vllm_config=mock_vllm_config,
83+
kv_cache_config=mock_kv_cache_config,
84+
structured_output_manager=mock_structured_output_manager,
85+
block_size=16,
86+
log_stats=True,
87+
)
9288

93-
# Verify schedulers were created
94-
assert len(scheduler.schedulers) == 2
95-
assert scheduler.dp_size == 2
96-
assert scheduler.log_stats is True
97-
assert len(scheduler.per_rank_kv_cache_configs) == 2
89+
# Verify schedulers were created
90+
assert len(scheduler.schedulers) == 2
91+
assert scheduler.dp_size == 2
92+
assert scheduler.log_stats is True
93+
assert len(scheduler.per_rank_kv_cache_configs) == 2
9894

99-
# Verify each rank got the correct config
100-
for rank_config in scheduler.per_rank_kv_cache_configs:
101-
assert rank_config.num_blocks == 50 # 100 / 2
95+
# Verify each rank got the correct config
96+
for rank_config in scheduler.per_rank_kv_cache_configs:
97+
assert rank_config.num_blocks == 50 # 100 / 2
10298

10399
def test_get_rank_token_counts(self, mock_vllm_config,
104100
mock_kv_cache_config,
@@ -296,9 +292,9 @@ def test_combine_cached_request_data(self, mock_vllm_config,
296292
mock_kv_cache_config,
297293
mock_structured_output_manager):
298294
"""Test _combine_cached_request_data combines data from all ranks."""
299-
with patch(
300-
"tpu_inference.core.sched.dp_scheduler.resolve_obj_by_qualname"
301-
):
295+
mock_scheduler_cls = MagicMock(return_value=MagicMock())
296+
with patch.object(mock_vllm_config.scheduler_config,
297+
'_original_scheduler_cls', mock_scheduler_cls):
302298
scheduler = DPScheduler(
303299
vllm_config=mock_vllm_config,
304300
kv_cache_config=mock_kv_cache_config,
@@ -403,9 +399,9 @@ def test_get_grammar_bitmask_no_structured_output(
403399
self, mock_vllm_config, mock_kv_cache_config,
404400
mock_structured_output_manager):
405401
"""Test get_grammar_bitmask returns None when no structured output."""
406-
with patch(
407-
"tpu_inference.core.sched.dp_scheduler.resolve_obj_by_qualname"
408-
):
402+
mock_scheduler_cls = MagicMock(return_value=MagicMock())
403+
with patch.object(mock_vllm_config.scheduler_config,
404+
'_original_scheduler_cls', mock_scheduler_cls):
409405
scheduler = DPScheduler(
410406
vllm_config=mock_vllm_config,
411407
kv_cache_config=mock_kv_cache_config,
@@ -452,9 +448,9 @@ def test_update_from_output_routes_to_schedulers(
452448
self, mock_vllm_config, mock_kv_cache_config,
453449
mock_structured_output_manager):
454450
"""Test update_from_output splits output and updates each scheduler."""
455-
with patch(
456-
"tpu_inference.core.sched.dp_scheduler.resolve_obj_by_qualname"
457-
):
451+
mock_scheduler_cls = MagicMock(return_value=MagicMock())
452+
with patch.object(mock_vllm_config.scheduler_config,
453+
'_original_scheduler_cls', mock_scheduler_cls):
458454
scheduler = DPScheduler(
459455
vllm_config=mock_vllm_config,
460456
kv_cache_config=mock_kv_cache_config,
@@ -551,9 +547,9 @@ def test_split_model_output_by_rank(self, mock_vllm_config,
551547
mock_kv_cache_config,
552548
mock_structured_output_manager):
553549
"""Test _split_model_output_by_rank distributes output correctly."""
554-
with patch(
555-
"tpu_inference.core.sched.dp_scheduler.resolve_obj_by_qualname"
556-
):
550+
mock_scheduler_cls = MagicMock(return_value=MagicMock())
551+
with patch.object(mock_vllm_config.scheduler_config,
552+
'_original_scheduler_cls', mock_scheduler_cls):
557553
scheduler = DPScheduler(
558554
vllm_config=mock_vllm_config,
559555
kv_cache_config=mock_kv_cache_config,
@@ -597,9 +593,9 @@ def test_cleanup_finished_requests(self, mock_vllm_config,
597593
mock_kv_cache_config,
598594
mock_structured_output_manager):
599595
"""Test _cleanup_finished_requests removes finished requests."""
600-
with patch(
601-
"tpu_inference.core.sched.dp_scheduler.resolve_obj_by_qualname"
602-
):
596+
mock_scheduler_cls = MagicMock(return_value=MagicMock())
597+
with patch.object(mock_vllm_config.scheduler_config,
598+
'_original_scheduler_cls', mock_scheduler_cls):
603599
scheduler = DPScheduler(
604600
vllm_config=mock_vllm_config,
605601
kv_cache_config=mock_kv_cache_config,
@@ -669,9 +665,9 @@ def test_has_finished_requests(self, mock_vllm_config,
669665
mock_kv_cache_config,
670666
mock_structured_output_manager):
671667
"""Test has_finished_requests checks all ranks."""
672-
with patch(
673-
"tpu_inference.core.sched.dp_scheduler.resolve_obj_by_qualname"
674-
):
668+
mock_scheduler_cls = MagicMock(return_value=MagicMock())
669+
with patch.object(mock_vllm_config.scheduler_config,
670+
'_original_scheduler_cls', mock_scheduler_cls):
675671
scheduler = DPScheduler(
676672
vllm_config=mock_vllm_config,
677673
kv_cache_config=mock_kv_cache_config,
@@ -798,9 +794,9 @@ def test_make_stats_with_logging_disabled(self, mock_vllm_config,
798794
mock_kv_cache_config,
799795
mock_structured_output_manager):
800796
"""Test make_stats returns None when logging is disabled."""
801-
with patch(
802-
"tpu_inference.core.sched.dp_scheduler.resolve_obj_by_qualname"
803-
):
797+
mock_scheduler_cls = MagicMock(return_value=MagicMock())
798+
with patch.object(mock_vllm_config.scheduler_config,
799+
'_original_scheduler_cls', mock_scheduler_cls):
804800
scheduler = DPScheduler(
805801
vllm_config=mock_vllm_config,
806802
kv_cache_config=mock_kv_cache_config,
@@ -878,11 +874,12 @@ def test_update_config_with_dp_size_greater_than_one(self):
878874
mock_config.sharding_config.total_dp_size = 2
879875
mock_config.scheduler_config._original_scheduler_cls = None
880876
mock_config.scheduler_config.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
877+
mock_config.scheduler_config.async_scheduling = False
881878

882879
update_vllm_config_for_dp_scheduler(mock_config)
883880

884881
# Verify config was updated
885-
assert mock_config.scheduler_config._original_scheduler_cls == "vllm.v1.core.sched.scheduler.Scheduler"
882+
assert mock_config.scheduler_config._original_scheduler_cls == Scheduler
886883
assert mock_config.scheduler_config.scheduler_cls == DPScheduler
887884

888885
def test_update_config_with_dp_size_one(self):

tpu_inference/core/sched/dp_scheduler.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
from vllm.config import VllmConfig
88
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
9-
from vllm.utils.import_utils import resolve_obj_by_qualname
9+
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
1010
from vllm.v1.core.sched.interface import SchedulerInterface
1111
from vllm.v1.core.sched.output import (CachedRequestData, GrammarOutput,
1212
SchedulerOutput)
@@ -76,8 +76,7 @@ def __init__(
7676
self._create_per_rank_configs(kv_cache_config)
7777

7878
# The original scheduler class could be Scheduler or AsyncScheduler
79-
original_scheduler_cls = resolve_obj_by_qualname(
80-
vllm_config.scheduler_config._original_scheduler_cls)
79+
original_scheduler_cls = vllm_config.scheduler_config._original_scheduler_cls
8180
self.schedulers: List[Scheduler] = []
8281
for rank in range(self.dp_size):
8382
scheduler = original_scheduler_cls(
@@ -92,7 +91,8 @@ def __init__(
9291
self.schedulers.append(scheduler)
9392

9493
logger.info(
95-
f"DPScheduler per-rank limits: max_seqs={self.vllm_config.scheduler_config.max_num_seqs}, "
94+
f"DPScheduler (Async = {self.vllm_config.scheduler_config.async_scheduling}) "
95+
f"per-rank limits: max_seqs={self.vllm_config.scheduler_config.max_num_seqs}, "
9696
f"max_tokens={self.vllm_config.scheduler_config.max_num_batched_tokens}"
9797
)
9898

@@ -515,5 +515,9 @@ def update_vllm_config_for_dp_scheduler(vllm_config: Any) -> None:
515515
dp_size = vllm_config.sharding_config.total_dp_size
516516

517517
if dp_size > 1:
518-
vllm_config.scheduler_config._original_scheduler_cls = vllm_config.scheduler_config.scheduler_cls
518+
if vllm_config.scheduler_config.async_scheduling:
519+
vllm_config.scheduler_config._original_scheduler_cls = AsyncScheduler
520+
else:
521+
vllm_config.scheduler_config._original_scheduler_cls = Scheduler
522+
519523
vllm_config.scheduler_config.scheduler_cls = DPScheduler

0 commit comments

Comments
 (0)