Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 45 additions & 7 deletions components/src/dynamo/vllm/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@


def build_sampling_params(
request: Dict[str, Any], default_sampling_params: Dict[str, Any]
request: Dict[str, Any],
default_sampling_params: Dict[str, Any],
model_max_len: int | None = None,
) -> SamplingParams:
"""
Build SamplingParams from a PreprocessedRequest.
Expand Down Expand Up @@ -57,6 +59,18 @@ def build_sampling_params(
continue
setattr(sampling_params, key, value)

# If max_tokens wasn't provided (None or missing), compute a dynamic default
try:
provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None)
token_ids = request.get("token_ids", [])
input_length = len(token_ids)
if model_max_len is not None and (provided_max_tokens is None):
# Ensure at least 1 token generation by default when possible
dynamic_default = max(1, model_max_len - input_length)
sampling_params.max_tokens = dynamic_default
except Exception:
pass

return sampling_params


Expand All @@ -65,14 +79,22 @@ class BaseWorkerHandler(ABC):
Request handler for the generate and clear_kv_blocks endpoints.
"""

def __init__(self, runtime, component, engine, default_sampling_params):
def __init__(
self,
runtime,
component,
engine,
default_sampling_params,
model_max_len: int | None = None,
):
self.runtime = runtime
self.component = component
self.engine_client = engine
self.default_sampling_params = default_sampling_params
self.kv_publishers: list[ZmqKvEventPublisher] | None = None
self.engine_monitor = VllmEngineMonitor(runtime, engine)
self.image_loader = ImageLoader()
self.model_max_len = model_max_len

@abstractmethod
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
Expand Down Expand Up @@ -212,8 +234,11 @@ def __init__(
component,
engine,
default_sampling_params,
model_max_len: int | None = None,
):
super().__init__(runtime, component, engine, default_sampling_params)
super().__init__(
runtime, component, engine, default_sampling_params, model_max_len
)

async def generate(self, request, context):
# Use context ID for request tracking and correlation
Expand All @@ -228,7 +253,9 @@ async def generate(self, request, context):
)

# Build sampling params from request
sampling_params = build_sampling_params(request, self.default_sampling_params)
sampling_params = build_sampling_params(
request, self.default_sampling_params, self.model_max_len
)

# Extract disaggregated_params from request (set by prefill router in Rust frontend)
disaggregated_params = request.get("disaggregated_params")
Expand Down Expand Up @@ -259,8 +286,17 @@ async def generate(self, request, context):


class PrefillWorkerHandler(BaseWorkerHandler):
def __init__(self, runtime, component, engine, default_sampling_params):
super().__init__(runtime, component, engine, default_sampling_params)
def __init__(
self,
runtime,
component,
engine,
default_sampling_params,
model_max_len: int | None = None,
):
super().__init__(
runtime, component, engine, default_sampling_params, model_max_len
)

async def generate(self, request, context):
# Use context ID for request tracking and correlation with decode phase
Expand All @@ -276,7 +312,9 @@ async def generate(self, request, context):
)

# Build sampling params from request using shared utility
sampling_params = build_sampling_params(request, self.default_sampling_params)
sampling_params = build_sampling_params(
request, self.default_sampling_params, self.model_max_len
)

# Configure for prefill-only mode with remote decode
if sampling_params.extra_args is None:
Expand Down
7 changes: 6 additions & 1 deletion components/src/dynamo/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config)

handler = PrefillWorkerHandler(
runtime, component, engine_client, default_sampling_params
runtime,
component,
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
)

# Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
Expand Down Expand Up @@ -424,6 +428,7 @@ async def init(runtime: DistributedRuntime, config: Config):
component,
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
)

# Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
Expand Down
Loading