From 8272787effe6a03c369459b79a70604f0e8cb9c2 Mon Sep 17 00:00:00 2001
From: GDzhu01 <809721801@qq.com>
Date: Mon, 1 Dec 2025 09:11:45 +0800
Subject: [PATCH] support balance scheduling

Signed-off-by: GDzhu01 <809721801@qq.com>
---
 vllm/config/scheduler.py        |  7 +++++++
 vllm/v1/core/sched/scheduler.py | 17 +++++++++++++++++
 vllm/v1/engine/core.py          |  3 +++
 3 files changed, 27 insertions(+)

diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index ff1ac0e18f32..b14bfc0e2b39 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -141,6 +141,13 @@ class SchedulerConfig:
     while a larger value (e.g., 10) reduces host overhead and may
     increase throughput by batching multiple tokens before sending."""
 
+    balance_scheduling: bool = False
+    """EXPERIMENTAL: If set to True, enable balance scheduling across data
+    parallel (DP) ranks. This may increase output throughput and reduce TPOT
+    in the v1 scheduler, but TTFT may degrade in some scenarios. Enabling
+    this feature is not recommended in prefill/decode (PD) disaggregated
+    deployments."""
+
     def get_scheduler_cls(self) -> type["SchedulerInterface"]:
         if self.scheduler_cls is None:
             if self.async_scheduling:
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index e3ec8440a932..517b5bc92ad6 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -6,6 +6,8 @@
 from collections.abc import Iterable
 from typing import Any
 
+import torch
+import torch.distributed as dist
 from vllm import envs
 from vllm.config import VllmConfig
 from vllm.distributed.ec_transfer.ec_connector.base import (
@@ -191,6 +193,16 @@ def __init__(
         self.use_pp = self.parallel_config.pipeline_parallel_size > 1
         self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
 
+        # Balance scheduling: one CPU tensor slot per DP rank, filled in by
+        if self.vllm_config.scheduler_config.balance_scheduling:
+            self.balance_queue = [
+                torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(self.vllm_config.parallel_config.data_parallel_size)
+            ]
+
+    def balance_gather(self, dp_group):
+        running_tensor = torch.tensor([len(self.running)], dtype=torch.int, device="cpu")
+        dist.all_gather(self.balance_queue, running_tensor, group=dp_group)
+
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -407,6 +419,11 @@
                 if len(self.running) == self.max_num_running_reqs:
                     break
 
+                if self.vllm_config.scheduler_config.balance_scheduling:
+                    balance_flag = max(t.item() for t in self.balance_queue) == self.max_num_running_reqs
+                    if balance_flag:
+                        break
+
                 request = self.waiting.peek_request()
 
                 # KVTransfer: skip request if still waiting for remote kvs.
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index e3a5f51a8fc5..c1b602885cbe 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -1229,6 +1229,9 @@ def run_busy_loop(self):
                 local_unfinished_reqs
             )
 
+            if self.vllm_config.scheduler_config.balance_scheduling:
+                self.scheduler.balance_gather(self.dp_group)
+
             if not self.engines_running:
                 if self.dp_rank == 0 or not self.has_coordinator:
                     # Notify client that we are pausing the loop.
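
Note for reviewers (illustration only, not part of the diff): the snippet below
is a minimal, runnable sketch of the all_gather pattern that balance_gather()
and the new check in schedule() rely on, shown in isolation. The gloo process
group, the two simulated DP ranks, the hard-coded per-rank running counts, and
the MAX_NUM_RUNNING_REQS constant are stand-ins invented for this example and
are not part of vLLM.

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    MAX_NUM_RUNNING_REQS = 4  # hypothetical per-rank cap, illustration only


    def dp_rank_main(rank: int, world_size: int) -> None:
        # Stand-in for vLLM's DP process group (CPU tensors, gloo backend).
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29511"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Pretend rank 0 is saturated while rank 1 still has headroom.
        num_running = MAX_NUM_RUNNING_REQS if rank == 0 else 1

        # Same pattern as Scheduler.balance_gather(): every rank publishes its
        # running-request count and receives the counts of all other ranks.
        balance_queue = [
            torch.tensor([0], dtype=torch.int) for _ in range(world_size)
        ]
        dist.all_gather(balance_queue,
                        torch.tensor([num_running], dtype=torch.int))

        # Same gate as the new check in schedule(): stop admitting waiting
        # requests while any DP rank is already at its running-request cap.
        balance_flag = (
            max(t.item() for t in balance_queue) == MAX_NUM_RUNNING_REQS
        )
        print(f"rank {rank}: counts={[t.item() for t in balance_queue]}, "
              f"pause_admission={balance_flag}")

        dist.destroy_process_group()


    if __name__ == "__main__":
        world_size = 2
        mp.spawn(dp_rank_main, args=(world_size,), nprocs=world_size)

As soon as any rank reports a count equal to the cap, every rank computes
balance_flag as True, which is the condition under which the patched schedule()
stops pulling new requests from the waiting queue.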
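
Usage note (also illustrative, not part of the diff): balance_scheduling is an
ordinary SchedulerConfig dataclass field, so it can be toggled programmatically;
the snippet below assumes the remaining SchedulerConfig fields can keep their
defaults.

    from vllm.config.scheduler import SchedulerConfig

    scheduler_config = SchedulerConfig(balance_scheduling=True)
    assert scheduler_config.balance_scheduling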