Skip to content
Open
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions components/src/dynamo/vllm/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class Config:
multimodal_decode_worker: bool = False
multimodal_encode_prefill_worker: bool = False
mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:"
frontend_decoding: bool = False
# dump config to file
dump_config_to: Optional[str] = None

Expand Down Expand Up @@ -175,6 +176,16 @@ def parse_args() -> Config:
"'USER: <image> please describe the image ASSISTANT:'."
),
)
parser.add_argument(
"--frontend-decoding",
action="store_true",
help=(
"EXPERIMENTAL: Enable frontend decoding of multimodal images. "
"When enabled, images are decoded in the Rust frontend and transferred to the backend via NIXL RDMA. "
"Requires building Dynamo's Rust components with '--features media-nixl'. "
"Without this flag, images are decoded in the Python backend (default behavior)."
),
)
parser.add_argument(
"--store-kv",
type=str,
Expand Down
85 changes: 70 additions & 15 deletions components/src/dynamo/vllm/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Final

import PIL
import torch
from vllm.inputs import TokensPrompt
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.exceptions import EngineDeadError

import dynamo.nixl_connect as connect
from dynamo.llm import ZmqKvEventPublisher
from dynamo.nixl_connect import OperationKind, RdmaMetadata, SerializedDescriptor
from dynamo.runtime.logging import configure_dynamo_logging

from .engine_monitor import VllmEngineMonitor
Expand Down Expand Up @@ -93,6 +97,7 @@ def __init__(
self.kv_publishers: list[ZmqKvEventPublisher] | None = None
self.engine_monitor = VllmEngineMonitor(runtime, engine)
self.image_loader = ImageLoader()
self._connector = None # Lazy-initialized on first Decoded variant
self.temp_dirs: list[tempfile.TemporaryDirectory] = []
self.model_max_len = model_max_len

Expand Down Expand Up @@ -150,11 +155,63 @@ def cleanup(self):
except Exception as e:
logger.warning(f"Failed to clean up temp directory: {e}")

async def _read_decoded_image_via_nixl(
    self, decoded_meta: Dict[str, Any]
) -> "PIL.Image.Image":
    """Read a frontend-decoded image over NIXL RDMA and return a PIL image.

    Args:
        decoded_meta: Metadata describing the decoded image. This method
            reads the keys ``"nixl_metadata"``, ``"nixl_descriptor"`` (a
            mapping with ``"addr"``, ``"size"`` and optionally
            ``"mem_type"`` / ``"device_id"``) and ``"shape"``.

    Returns:
        The image built from the received uint8 buffer: mode "L" for one
        channel, "RGB" for three, "RGBA" for four.

    Raises:
        KeyError: If a required key is missing from ``decoded_meta`` or
            from its descriptor.
        ValueError: If ``shape`` is not three-dimensional with 1, 3 or 4
            as its last dimension.
    """
    # Extract fields
    meta_str = decoded_meta["nixl_metadata"]
    desc = decoded_meta["nixl_descriptor"]
    shape = decoded_meta["shape"]

    # Validate before doing any RDMA work so that malformed metadata fails
    # fast with a descriptive error instead of a KeyError/IndexError raised
    # after the transfer has already completed.
    if len(shape) != 3 or shape[2] not in (1, 3, 4):
        raise ValueError(
            "Unsupported decoded image shape "
            f"{shape!r}: expected 3 dimensions with 1, 3 or 4 channels last"
        )

    # Lazy-init connector. Publish it on self only after initialize() has
    # finished, so a concurrent request never picks up a half-initialized
    # connector while this coroutine is suspended in the await.
    if self._connector is None:
        connector = connect.Connector()
        await connector.initialize()
        if self._connector is None:
            self._connector = connector
        logger.info("NIXL connector initialized for decoded media")

    # Create tensor to receive RDMA data.
    # NOTE(review): the buffer is always uint8; presumably the frontend only
    # sends 8-bit images -- confirm against decoded_meta["dtype"].
    tensor = torch.empty(shape, dtype=torch.uint8)

    # Build RdmaMetadata from frontend-provided descriptor
    # Frontend sends compressed metadata (matches Python nixl_connect)
    # NOTE(review): the notification key depends only on the shape, so two
    # in-flight images with equal shapes share a key -- confirm that the
    # frontend does not rely on it being unique.
    device = (
        "cpu"
        if desc.get("mem_type") == "Dram"
        else f"cuda:{desc.get('device_id', 0)}"
    )
    rdma_meta = RdmaMetadata(
        descriptors=[
            SerializedDescriptor(
                device=device,
                ptr=desc["addr"],
                size=desc["size"],
            )
        ],
        nixl_metadata=meta_str,
        notification_key=f"img-{shape}",
        operation_kind=int(OperationKind.READ),
    )

    # RDMA read
    # NOTE(review): the PR author asked for a NIXL expert to double-check
    # this read sequence for possible improvements.
    read_op = await self._connector.begin_read(
        rdma_meta, connect.Descriptor(tensor)
    )
    await read_op.wait_for_completion()

    # Import the submodule explicitly: a bare `import PIL` does not
    # guarantee that `PIL.Image` has been loaded.
    from PIL import Image

    # Convert to PIL.Image. For uint8 data fromarray infers the mode from
    # the array shape ((H, W) -> "L", (H, W, 3) -> "RGB", (H, W, 4) ->
    # "RGBA"), so single-channel data only needs its trailing axis dropped.
    arr = tensor.numpy()
    if shape[2] == 1:
        arr = arr.squeeze(-1)
    return Image.fromarray(arr)

async def _extract_multimodal_data(
self, request: Dict[str, Any]
) -> Dict[str, Any] | None:
"""
Extract and decode multimodal data from PreprocessedRequest.

Supports two variants:
1. Url: Frontend passes URL, backend decodes
2. Decoded: Frontend decoded, NIXL RDMA transfer
"""
if "multi_modal_data" not in request or request["multi_modal_data"] is None:
return None
Expand All @@ -165,22 +222,20 @@ async def _extract_multimodal_data(
# Process image_url entries
images = []
for item in mm_map.get(IMAGE_URL_KEY, []):
if isinstance(item, dict) and URL_VARIANT_KEY in item:
if isinstance(item, dict) and DECODED_VARIANT_KEY in item:
decoded_meta = item[DECODED_VARIANT_KEY]
image = await self._read_decoded_image_via_nixl(decoded_meta)
images.append(image)
logger.info(
f"Using DECODED path: Loaded image via NIXL RDMA "
f"(shape={decoded_meta.get('shape')}, dtype={decoded_meta.get('dtype')})"
)
elif isinstance(item, dict) and URL_VARIANT_KEY in item:
url = item[URL_VARIANT_KEY]
try:
# ImageLoader supports both data: and http(s): URLs with caching
image = await self.image_loader.load_image(url)
images.append(image)
logger.debug(f"Loaded image from URL: {url[:80]}...")
except Exception:
logger.exception(f"Failed to load image from {url[:80]}...")
raise
elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
# Decoded support from PRs #3971/#3988 (frontend decoding + NIXL transfer)
# Will contain NIXL metadata for direct memory access
# TODO: Implement NIXL read when PRs merge
logger.warning(
"Decoded multimodal data not yet supported in standard worker"
image = await self.image_loader.load_image(url)
images.append(image)
logger.info(
f"Using URL path: Loaded image from URL (type={url.split(':')[0]})"
)

if images:
Expand Down
22 changes: 22 additions & 0 deletions components/src/dynamo/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,26 @@ async def register_vllm_model(
data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
runtime_config.data_parallel_size = data_parallel_size

# Conditionally enable frontend decoding if --frontend-decoding flag is set
media_decoder = None
media_fetcher = None
if config.frontend_decoding:
try:
from dynamo.llm import MediaDecoder, MediaFetcher

media_decoder = MediaDecoder()
media_fetcher = MediaFetcher()
logger.info(
"Frontend decoding enabled: images will be decoded in Rust frontend "
"and transferred via NIXL RDMA"
)
except ImportError as e:
raise RuntimeError(
"Frontend decoding (--frontend-decoding) requires building Dynamo's "
"Rust components with '--features media-nixl'. "
f"Import failed: {e}"
) from e

await register_llm(
model_input,
model_type,
Expand All @@ -319,6 +339,8 @@ async def register_vllm_model(
migration_limit=migration_limit,
runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template,
media_decoder=media_decoder,
media_fetcher=media_fetcher,
)


Expand Down
117 changes: 116 additions & 1 deletion lib/bindings/python/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading