From 5cb6890925867ad9819c0eb17222fd44cd1ee48e Mon Sep 17 00:00:00 2001 From: bin Date: Thu, 6 Nov 2025 14:29:09 +0000 Subject: [PATCH] support a dynamic default max_tokens for VLLM backend Signed-off-by: bin --- components/src/dynamo/vllm/handlers.py | 52 ++++++++++++++++++++++---- components/src/dynamo/vllm/main.py | 7 +++- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index 6d02e0da53..149567b76f 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -29,7 +29,9 @@ def build_sampling_params( - request: Dict[str, Any], default_sampling_params: Dict[str, Any] + request: Dict[str, Any], + default_sampling_params: Dict[str, Any], + model_max_len: int | None = None, ) -> SamplingParams: """ Build SamplingParams from a PreprocessedRequest. @@ -57,6 +59,18 @@ def build_sampling_params( continue setattr(sampling_params, key, value) + # If max_tokens wasn't provided (None or missing), compute a dynamic default + try: + provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None) + token_ids = request.get("token_ids", []) + input_length = len(token_ids) + if model_max_len is not None and (provided_max_tokens is None): + # Ensure at least 1 token generation by default when possible + dynamic_default = max(1, model_max_len - input_length) + sampling_params.max_tokens = dynamic_default + except Exception: + pass + return sampling_params @@ -65,7 +79,14 @@ class BaseWorkerHandler(ABC): Request handler for the generate and clear_kv_blocks endpoints. """ - def __init__(self, runtime, component, engine, default_sampling_params): + def __init__( + self, + runtime, + component, + engine, + default_sampling_params, + model_max_len: int | None = None, + ): self.runtime = runtime self.component = component self.engine_client = engine @@ -73,6 +94,7 @@ def __init__(self, runtime, component, engine, default_sampling_params): self.kv_publishers: list[ZmqKvEventPublisher] | None = None self.engine_monitor = VllmEngineMonitor(runtime, engine) self.image_loader = ImageLoader() + self.model_max_len = model_max_len @abstractmethod async def generate(self, request, context) -> AsyncGenerator[dict, None]: @@ -212,8 +234,11 @@ def __init__( component, engine, default_sampling_params, + model_max_len: int | None = None, ): - super().__init__(runtime, component, engine, default_sampling_params) + super().__init__( + runtime, component, engine, default_sampling_params, model_max_len + ) async def generate(self, request, context): # Use context ID for request tracking and correlation @@ -228,7 +253,9 @@ async def generate(self, request, context): ) # Build sampling params from request - sampling_params = build_sampling_params(request, self.default_sampling_params) + sampling_params = build_sampling_params( + request, self.default_sampling_params, self.model_max_len + ) # Extract disaggregated_params from request (set by prefill router in Rust frontend) disaggregated_params = request.get("disaggregated_params") @@ -259,8 +286,17 @@ async def generate(self, request, context): class PrefillWorkerHandler(BaseWorkerHandler): - def __init__(self, runtime, component, engine, default_sampling_params): - super().__init__(runtime, component, engine, default_sampling_params) + def __init__( + self, + runtime, + component, + engine, + default_sampling_params, + model_max_len: int | None = None, + ): + super().__init__( + runtime, component, engine, default_sampling_params, model_max_len + ) async def generate(self, request, context): # Use context ID for request tracking and correlation with decode phase @@ -276,7 +312,9 @@ async def generate(self, request, context): ) # Build sampling params from request using shared utility - sampling_params = build_sampling_params(request, self.default_sampling_params) + sampling_params = build_sampling_params( + request, self.default_sampling_params, self.model_max_len + ) # Configure for prefill-only mode with remote decode if sampling_params.extra_args is None: diff --git a/components/src/dynamo/vllm/main.py b/components/src/dynamo/vllm/main.py index 060390d01b..d37f7cb5c0 100644 --- a/components/src/dynamo/vllm/main.py +++ b/components/src/dynamo/vllm/main.py @@ -317,7 +317,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config) handler = PrefillWorkerHandler( - runtime, component, engine_client, default_sampling_params + runtime, + component, + engine_client, + default_sampling_params, + getattr(getattr(vllm_config, "model_config", None), "max_model_len", None), ) # Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine) @@ -424,6 +428,7 @@ async def init(runtime: DistributedRuntime, config: Config): component, engine_client, default_sampling_params, + getattr(getattr(vllm_config, "model_config", None), "max_model_len", None), ) # Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)