7 changes: 7 additions & 0 deletions vllm/config/scheduler.py
@@ -141,6 +141,13 @@ class SchedulerConfig:
    while a larger value (e.g., 10) reduces host overhead and may increase throughput
    by batching multiple tokens before sending."""

    balance_scheduling: bool = False
    """EXPERIMENTAL: If set to True, perform balance scheduling across
    data-parallel ranks. This may help increase output throughput and reduce
    TPOT in the v1 scheduler. However, TTFT may degrade in some scenarios,
    and enabling this feature is not recommended in prefill-decode (PD)
    disaggregated deployments.
    """

    def get_scheduler_cls(self) -> type["SchedulerInterface"]:
        if self.scheduler_cls is None:
            if self.async_scheduling:
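For context, a minimal sketch of how the new flag could be exercised. The field name comes from this diff; constructing `SchedulerConfig` directly (rather than through `EngineArgs` or the CLI) and relying on the other fields' defaults are illustrative assumptions:

# Illustrative sketch only; not the PR's documented usage.
from vllm.config.scheduler import SchedulerConfig

config = SchedulerConfig(balance_scheduling=True)  # EXPERIMENTAL, defaults to False
assert config.balance_scheduling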
17 changes: 17 additions & 0 deletions vllm/v1/core/sched/scheduler.py
@@ -6,6 +6,8 @@
from collections.abc import Iterable
from typing import Any

import torch
import torch.distributed as dist
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
Expand Down Expand Up @@ -191,6 +193,16 @@
        self.use_pp = self.parallel_config.pipeline_parallel_size > 1
        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

        # Balance scheduling.
        if self.vllm_config.scheduler_config.balance_scheduling:
            self.balance_queue = [
                torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(self.vllm_config.parallel_config.data_parallel_size)
            ]

Check failure on line 200 in vllm/v1/core/sched/scheduler.py (GitHub Actions / pre-commit): Ruff (E501) — Line too long (132 > 88)
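One straightforward wrap that would satisfy Ruff's 88-column limit (a sketch of a possible fix, not taken from the PR):

            self.balance_queue = [
                torch.tensor([0], dtype=torch.int, device="cpu")
                for _ in range(self.vllm_config.parallel_config.data_parallel_size)
            ]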

    def balance_gather(self, dp_group):
        runing_tensor = torch.tensor([len(self.running)], dtype=torch.int, device="cpu")
Contributor review comment (severity: high):
There is a typo in the variable name `runing_tensor`. It should be `running_tensor` to improve code clarity and maintainability.

Suggested change:
-        runing_tensor = torch.tensor([len(self.running)], dtype=torch.int, device="cpu")
+        running_tensor = torch.tensor([len(self.running)], dtype=torch.int, device="cpu")

        dist.all_gather(self.balance_queue, runing_tensor, group=dp_group)
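For readers less familiar with `torch.distributed`, here is a standalone sketch of the same `all_gather` pattern, runnable on one machine with the `gloo` backend. The spawn harness, world size, and fake running counts are illustrative assumptions, not part of the PR:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def main(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # One pre-allocated slot per data-parallel rank, mirroring self.balance_queue.
    balance_queue = [torch.tensor([0], dtype=torch.int) for _ in range(world_size)]
    # Pretend each rank currently has a different number of running requests.
    running_tensor = torch.tensor([rank * 10 + 1], dtype=torch.int)

    # After the collective, every rank holds every rank's running count.
    dist.all_gather(balance_queue, running_tensor)
    print(f"rank {rank}: {[t.item() for t in balance_queue]}")

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(main, args=(world_size,), nprocs=world_size)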

    def schedule(self) -> SchedulerOutput:
        # NOTE(woosuk) on the scheduling algorithm:
        # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -407,6 +419,11 @@
                if len(self.running) == self.max_num_running_reqs:
                    break

                if self.vllm_config.scheduler_config.balance_scheduling:
                    balance_flag = max(t.item() for t in self.balance_queue) == self.max_num_running_reqs
                    if balance_flag:
                        break

Check failure on line 424 in vllm/v1/core/sched/scheduler.py (GitHub Actions / pre-commit): Ruff (E501) — Line too long (105 > 88)

                request = self.waiting.peek_request()

                # KVTransfer: skip request if still waiting for remote kvs.
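The gate above stops every rank from admitting new waiting requests as soon as any data-parallel rank reports a full running queue, since `balance_queue` holds the gathered per-rank counts. A tiny illustration with made-up values:

import torch

max_num_running_reqs = 4
# Gathered running-queue lengths from three hypothetical DP ranks.
balance_queue = [torch.tensor([n], dtype=torch.int) for n in (2, 4, 1)]

balance_flag = max(t.item() for t in balance_queue) == max_num_running_reqs
assert balance_flag  # rank 1 is saturated, so every rank stops admitting requests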
3 changes: 3 additions & 0 deletions vllm/v1/engine/core.py
@@ -1229,6 +1229,9 @@
local_unfinished_reqs
)

            if self.vllm_config.scheduler_config.balance_scheduling:
                self.scheduler.balance_gather(self.dp_group)

Check failure on line 1233 in vllm/v1/engine/core.py (GitHub Actions / pre-commit), reported 4 times: "SchedulerInterface" has no attribute "balance_gather" [attr-defined]

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
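The repeated attr-defined failure above arises because `balance_gather` is defined only on the concrete scheduler while `self.scheduler` is typed as `SchedulerInterface`. One conventional resolution, sketched under the assumption that the interface can grow a no-op default (this is not the PR's actual fix):

# Hypothetical sketch: declare the hook on the interface with a no-op default
# so every implementation satisfies the type checker.
from torch.distributed import ProcessGroup


class SchedulerInterface:  # abridged stand-in for vllm's real interface
    def balance_gather(self, dp_group: ProcessGroup) -> None:
        """Default no-op; overridden by schedulers that balance across DP ranks."""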