diff --git a/Cargo.lock b/Cargo.lock
index 75d9af447b..05760d0f36 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2678,11 +2678,13 @@ dependencies = [
  "derive_builder",
  "dialoguer",
  "dynamo-async-openai",
+ "dynamo-memory",
  "dynamo-parsers",
  "dynamo-runtime",
  "either",
  "erased-serde",
  "etcd-client",
+ "flate2",
  "futures",
  "futures-util",
  "galil-seiferas",
@@ -4641,6 +4643,7 @@ dependencies = [
  "ravif",
  "rayon",
  "rgb",
+ "serde",
  "tiff",
  "zune-core 0.5.0",
  "zune-jpeg 0.5.5",
diff --git a/components/src/dynamo/vllm/args.py b/components/src/dynamo/vllm/args.py
index d20e1e8fc8..9fe555a8ef 100644
--- a/components/src/dynamo/vllm/args.py
+++ b/components/src/dynamo/vllm/args.py
@@ -65,6 +65,7 @@ class Config:
     multimodal_decode_worker: bool = False
     multimodal_encode_prefill_worker: bool = False
     mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:"
+    frontend_decoding: bool = False
     # dump config to file
     dump_config_to: Optional[str] = None
@@ -175,6 +176,16 @@ def parse_args() -> Config:
             "'USER: <image> please describe the image ASSISTANT:'."
         ),
     )
+    parser.add_argument(
+        "--frontend-decoding",
+        action="store_true",
+        help=(
+            "EXPERIMENTAL: Enable frontend decoding of multimodal images. "
+            "When enabled, images are decoded in the Rust frontend and transferred to the backend via NIXL RDMA. "
+            "Requires building Dynamo's Rust components with '--features media-nixl'. "
+            "Without this flag, images are decoded in the Python backend (default behavior)."
+        ),
+    )
     parser.add_argument(
         "--store-kv",
         type=str,
diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py
index ba49529955..268cea72df 100644
--- a/components/src/dynamo/vllm/handlers.py
+++ b/components/src/dynamo/vllm/handlers.py
@@ -9,12 +9,16 @@
 from contextlib import asynccontextmanager
 from typing import Any, AsyncGenerator, Dict, Final
 
+import PIL.Image
+import torch
 from vllm.inputs import TokensPrompt
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.v1.engine.exceptions import EngineDeadError
 
+import dynamo.nixl_connect as connect
 from dynamo.llm import ZmqKvEventPublisher
+from dynamo.nixl_connect import OperationKind, RdmaMetadata, SerializedDescriptor
 from dynamo.runtime.logging import configure_dynamo_logging
 
 from .engine_monitor import VllmEngineMonitor
@@ -93,6 +97,7 @@ def __init__(
         self.kv_publishers: list[ZmqKvEventPublisher] | None = None
         self.engine_monitor = VllmEngineMonitor(runtime, engine)
         self.image_loader = ImageLoader()
+        self._connector = None  # Lazy-initialized on first Decoded variant
         self.temp_dirs: list[tempfile.TemporaryDirectory] = []
         self.model_max_len = model_max_len
 
@@ -150,11 +155,63 @@ def cleanup(self):
             except Exception as e:
                 logger.warning(f"Failed to clean up temp directory: {e}")
 
+    async def _read_decoded_image_via_nixl(
+        self, decoded_meta: Dict[str, Any]
+    ) -> PIL.Image.Image:
+        """Read decoded image via NIXL RDMA and convert to PIL.Image."""
+        # Lazy-init connector
+        if self._connector is None:
+            self._connector = connect.Connector()
+            await self._connector.initialize()
+            logger.info("NIXL connector initialized for decoded media")
+
+        # Extract fields
+        meta_str = decoded_meta["nixl_metadata"]
+        desc = decoded_meta["nixl_descriptor"]
+        shape = decoded_meta["shape"]
+
+        # Create tensor to receive RDMA data
+        tensor = torch.empty(shape, dtype=torch.uint8)
+
+        # Build RdmaMetadata from frontend-provided descriptor
+        # Frontend sends compressed metadata (matches Python nixl_connect)
+        rdma_meta = RdmaMetadata(
+            descriptors=[
+                SerializedDescriptor(
+                    device="cpu"
+                    if desc.get("mem_type") == "Dram"
+                    else f"cuda:{desc.get('device_id', 0)}",
+                    ptr=desc["addr"],
+                    size=desc["size"],
+                )
+            ],
+            nixl_metadata=meta_str,
+            notification_key=f"img-{shape}",
+            operation_kind=int(OperationKind.READ),
+        )
+
+        # RDMA read
+        read_op = await self._connector.begin_read(
+            rdma_meta, connect.Descriptor(tensor)
+        )
+        await read_op.wait_for_completion()
+
+        # Convert to PIL.Image (assume RGB, handle RGBA/grayscale)
+        arr = tensor.numpy()
+        modes = {1: "L", 3: "RGB", 4: "RGBA"}
+        if modes[shape[2]] == "L":
+            arr = arr.squeeze(-1)
+        return PIL.Image.fromarray(arr, modes[shape[2]])
+
     async def _extract_multimodal_data(
         self, request: Dict[str, Any]
     ) -> Dict[str, Any] | None:
         """
         Extract and decode multimodal data from PreprocessedRequest.
+
+        Supports two variants:
+        1. Url: Frontend passes URL, backend decodes
+        2. Decoded: Frontend decoded, NIXL RDMA transfer
         """
         if "multi_modal_data" not in request or request["multi_modal_data"] is None:
             return None
@@ -165,22 +222,20 @@ async def _extract_multimodal_data(
         # Process image_url entries
         images = []
         for item in mm_map.get(IMAGE_URL_KEY, []):
-            if isinstance(item, dict) and URL_VARIANT_KEY in item:
+            if isinstance(item, dict) and DECODED_VARIANT_KEY in item:
+                decoded_meta = item[DECODED_VARIANT_KEY]
+                image = await self._read_decoded_image_via_nixl(decoded_meta)
+                images.append(image)
+                logger.info(
+                    f"Using DECODED path: Loaded image via NIXL RDMA "
+                    f"(shape={decoded_meta.get('shape')}, dtype={decoded_meta.get('dtype')})"
+                )
+            elif isinstance(item, dict) and URL_VARIANT_KEY in item:
                 url = item[URL_VARIANT_KEY]
-                try:
-                    # ImageLoader supports both data: and http(s): URLs with caching
-                    image = await self.image_loader.load_image(url)
-                    images.append(image)
-                    logger.debug(f"Loaded image from URL: {url[:80]}...")
-                except Exception:
-                    logger.exception(f"Failed to load image from {url[:80]}...")
-                    raise
-            elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
-                # Decoded support from PRs #3971/#3988 (frontend decoding + NIXL transfer)
-                # Will contain NIXL metadata for direct memory access
-                # TODO: Implement NIXL read when PRs merge
-                logger.warning(
-                    "Decoded multimodal data not yet supported in standard worker"
+                image = await self.image_loader.load_image(url)
+                images.append(image)
+                logger.info(
+                    f"Using URL path: Loaded image from URL (type={url.split(':')[0]})"
                 )
 
         if images:
diff --git a/components/src/dynamo/vllm/main.py b/components/src/dynamo/vllm/main.py
index 12bc9788b9..af919cc083 100644
--- a/components/src/dynamo/vllm/main.py
+++ b/components/src/dynamo/vllm/main.py
@@ -309,6 +309,26 @@ async def register_vllm_model(
     data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
     runtime_config.data_parallel_size = data_parallel_size
 
+    # Conditionally enable frontend decoding if --frontend-decoding flag is set
+    media_decoder = None
+    media_fetcher = None
+    if config.frontend_decoding:
+        try:
+            from dynamo.llm import MediaDecoder, MediaFetcher
+
+            media_decoder = MediaDecoder()
+            media_fetcher = MediaFetcher()
+            logger.info(
+                "Frontend decoding enabled: images will be decoded in Rust frontend "
+                "and transferred via NIXL RDMA"
+            )
+        except ImportError as e:
+            raise RuntimeError(
+                "Frontend decoding (--frontend-decoding) requires building Dynamo's "
+                "Rust components with '--features media-nixl'. "
" + f"Import failed: {e}" + ) from e + await register_llm( model_input, model_type, @@ -319,6 +339,8 @@ async def register_vllm_model( migration_limit=migration_limit, runtime_config=runtime_config, custom_template_path=config.custom_jinja_template, + media_decoder=media_decoder, + media_fetcher=media_fetcher, ) diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index 9ac49b7689..29a6ca07d3 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -371,7 +371,7 @@ version = "0.33.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1826f2e4cfc2cd19ee53c42fbf68e2f81ec21108e0b7ecf6a71cf062137360fc" dependencies = [ - "bindgen", + "bindgen 0.69.5", "cc", "cmake", "dunce", @@ -572,6 +572,26 @@ dependencies = [ "syn 2.0.110", ] +[[package]] +name = "bindgen" +version = "0.71.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" +dependencies = [ + "bitflags 2.9.3", + "cexpr", + "clang-sys", + "itertools 0.11.0", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 2.1.1", + "shlex", + "syn 2.0.106", +] + [[package]] name = "bit-set" version = "0.5.3" @@ -1533,6 +1553,13 @@ dependencies = [ "uuid", ] +[[package]] +name = "dynamo-config" +version = "0.7.0" +dependencies = [ + "anyhow", +] + [[package]] name = "dynamo-llm" version = "0.7.0" @@ -1561,11 +1588,13 @@ dependencies = [ "derive_builder", "dialoguer", "dynamo-async-openai", + "dynamo-memory", "dynamo-parsers", "dynamo-runtime", "either", "erased-serde", "etcd-client", + "flate2", "futures", "futures-util", "galil-seiferas", @@ -1581,6 +1610,7 @@ dependencies = [ "modelexpress-client", "modelexpress-common", "ndarray", + "nixl-sys", "ndarray-interp", "ndarray-npy", "offset-allocator", @@ -1622,6 +1652,22 @@ dependencies = [ "zeromq", ] +[[package]] +name = "dynamo-memory" +version = "0.7.0" +dependencies = [ + "anyhow", + "cudarc", + "dynamo-config", + "libc", + "nix 0.30.1", + "nixl-sys", + "offset-allocator", + "serde", + "thiserror 2.0.16", + "tracing", +] + [[package]] name = "dynamo-parsers" version = "0.7.0" @@ -1708,6 +1754,7 @@ dependencies = [ "local-ip-address", "log", "nid", + "nix 0.29.0", "nix", "notify", "nuid", @@ -2972,6 +3019,7 @@ dependencies = [ "ravif", "rayon", "rgb", + "serde", "tiff", "zune-core", "zune-jpeg", @@ -4029,6 +4077,34 @@ dependencies = [ "libc", ] +[[package]] +name = "nix" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" +dependencies = [ + "bitflags 2.9.3", + "cfg-if 1.0.3", + "cfg_aliases", + "libc", +] + +[[package]] +name = "nixl-sys" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a73b92494c94b2ff2d004cd9274d966863089e867dc9cd98bc640aefe7622036" +dependencies = [ + "bindgen 0.71.1", + "cc", + "libc", + "os_info", + "pkg-config", + "serde", + "thiserror 2.0.16", + "tracing", +] + [[package]] name = "nkeys" version = "0.4.5" @@ -4440,6 +4516,24 @@ dependencies = [ "hashbrown 0.14.5", ] +[[package]] +name = "os_info" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0e1ac5fde8d43c34139135df8ea9ee9465394b2d8d20f032d38998f64afffc3" +dependencies = [ + "log", + "plist", + "serde", + "windows-sys 0.52.0", +] + +[[package]] +name = "overload" +version = "0.1.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parking" version = "2.2.1" @@ -4668,6 +4762,19 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plist" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "740ebea15c5d1428f910cd1a5f52cebf8d25006245ed8ade92702f4943d91e07" +dependencies = [ + "base64 0.22.1", + "indexmap 2.11.0", + "quick-xml", + "serde", + "time", +] + [[package]] name = "png" version = "0.18.0" @@ -5127,6 +5234,15 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" +[[package]] +name = "quick-xml" +version = "0.38.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66c2058c55a409d601666cffe35f04333cf1013010882cec174a7467cd4e21c" +dependencies = [ + "memchr", +] + [[package]] name = "quinn" version = "0.11.9" diff --git a/lib/bindings/python/Cargo.toml b/lib/bindings/python/Cargo.toml index 845011e41e..e661d95a87 100644 --- a/lib/bindings/python/Cargo.toml +++ b/lib/bindings/python/Cargo.toml @@ -23,6 +23,7 @@ crate-type = ["cdylib", "rlib"] [features] default = [] +media-nixl = ["dynamo-llm/media-nixl"] [dependencies] dynamo-llm = { path = "../../llm" } diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index aa194df99d..a5c2d40655 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -24,6 +24,7 @@ testing-etcd = [] block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:nix", "dep:aligned-vec"] cuda = ["dep:cudarc"] integration = ["dynamo-runtime/integration"] +media-nixl = ["dep:nixl-sys", "dep:dynamo-memory"] [[bench]] name = "tokenizer" @@ -33,9 +34,11 @@ harness = false name = "transfer_context_v2" harness = false required-features = ["block-manager", "testing-cuda"] + [dependencies] # repo dynamo-runtime = { workspace = true } +dynamo-memory = { path = "../memory", version = "0.7.0", optional = true } # workspace aho-corasick = "1.1" @@ -145,9 +148,10 @@ json-five = { version = "0.3" } # media loading in the preprocessor reqwest = { workspace = true } base64 = { version = "0.22" } -image = { version = "0.25" } +image = { version = "0.25", features = ["serde"] } tokio-rayon = {version = "2" } ndarray = { version = "0.16" } +flate2 = { version = "1.0" } ndarray-npy = { version = "0.9" } ndarray-interp = { version = "0.5" } diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index ac9244a58b..8dd3dcb947 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -27,7 +27,8 @@ use std::{collections::HashMap, pin::Pin, sync::Arc}; use tracing; use crate::model_card::{ModelDeploymentCard, ModelInfo}; -use crate::preprocessor::media::MediaLoader; +#[cfg(feature = "media-nixl")] +use crate::preprocessor::media::{MediaDecoder, MediaFetcher, MediaLoader}; use crate::preprocessor::prompt::OAIChatLikeRequest; use crate::protocols::common::preprocessor::{ MultimodalData, MultimodalDataMap, PreprocessedRequestBuilder, @@ -114,6 +115,7 @@ pub struct OpenAIPreprocessor { /// Per-model runtime configuration propagated to response generator (e.g., reasoning/tool parser) runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig, tool_call_parser: Option, + #[cfg(feature = "media-nixl")] media_loader: 
 }
 
@@ -143,7 +145,13 @@ impl OpenAIPreprocessor {
         //
         // Initialize runtime config from the ModelDeploymentCard
         let runtime_config = mdc.runtime_config.clone();
-        let media_loader = None; // TODO: enable with decoder config from MDC
+
+        #[cfg(feature = "media-nixl")]
+        let media_loader = match mdc.media_decoder {
+            Some(media_decoder) => Some(MediaLoader::new(media_decoder, mdc.media_fetcher)?),
+            None => None,
+        };
+
         Ok(Arc::new(Self {
             formatter,
             tokenizer,
@@ -151,6 +159,7 @@
             mdcsum,
             runtime_config,
             tool_call_parser,
+            #[cfg(feature = "media-nixl")]
             media_loader,
         }))
     }
@@ -280,7 +289,9 @@
         let messages = request.messages();
         let message_count = messages.len().unwrap_or(0);
         let mut media_map: MultimodalDataMap = HashMap::new();
-        let mut fetch_tasks = Vec::new();
+        #[cfg(feature = "media-nixl")]
+        let mut fetch_tasks: Vec<(String, ChatCompletionRequestUserMessageContentPart)> =
+            Vec::new();
 
         for idx in 0..message_count {
             let msg = messages
@@ -313,29 +324,39 @@
                     _ => continue,
                 };
 
+                #[cfg(feature = "media-nixl")]
                 if self.media_loader.is_some() {
                     fetch_tasks.push((type_str, content_part.clone()));
-                } else {
-                    // No loader, just pass the URL through
-                    media_map
-                        .entry(type_str)
-                        .or_default()
-                        .push(MultimodalData::Url(url));
+                    continue;
                 }
+
+                // Fallback: just pass the URL through
+                media_map
+                    .entry(type_str)
+                    .or_default()
+                    .push(MultimodalData::Url(url));
             }
         }
 
         // Execute all fetch tasks
+        #[cfg(feature = "media-nixl")]
         if !fetch_tasks.is_empty() {
            let loader = self.media_loader.as_ref().unwrap();
-            let _results = futures::future::join_all(
+            let results = futures::future::join_all(
                fetch_tasks
                    .iter()
                    .map(|(_, content_part)| loader.fetch_and_decode_media_part(content_part)),
            )
            .await;
-            // TODO: decode and pass NIXL descriptors to the media map
+            for ((type_str, _), result) in fetch_tasks.into_iter().zip(results.into_iter()) {
+                // If one item fails, the whole request errors; the other items are cleaned up by Drop
+                let rdma_descriptor = result?;
+                media_map
+                    .entry(type_str)
+                    .or_default()
+                    .push(MultimodalData::Decoded(rdma_descriptor));
+            }
         }
 
         if !media_map.is_empty() {
diff --git a/lib/llm/src/preprocessor/media.rs b/lib/llm/src/preprocessor/media.rs
index 0c0e3e6b12..65566c9937 100644
--- a/lib/llm/src/preprocessor/media.rs
+++ b/lib/llm/src/preprocessor/media.rs
@@ -4,7 +4,12 @@
 mod common;
 mod decoders;
 mod loader;
+mod rdma;
 
 pub use common::EncodedMediaData;
 pub use decoders::{Decoder, ImageDecoder, MediaDecoder};
 pub use loader::{MediaFetcher, MediaLoader};
+
+pub use rdma::{DecodedMediaData, RdmaMediaDataDescriptor};
+#[cfg(feature = "media-nixl")]
+pub use rdma::{get_nixl_agent, get_nixl_metadata};
diff --git a/lib/llm/src/preprocessor/media/README.md b/lib/llm/src/preprocessor/media/README.md
new file mode 100644
index 0000000000..fede33bc9f
--- /dev/null
+++ b/lib/llm/src/preprocessor/media/README.md
@@ -0,0 +1,81 @@
+# Media decoding in the frontend
+
+This component performs media download, base64 decoding, media decoding, and NIXL registration. Today it is used in the OpenAI preprocessor to transform multimodal inputs (image_url, video_url, audio_url) into fully decoded data (pixel values, ...) that backends can access via NIXL.
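+
+Each decoded item is shipped to the backend as the `Decoded` variant of `MultimodalData`: the NIXL agent metadata plus a memory descriptor, with the tensor layout flattened alongside. A sketch of what the Python backend receives (field names as consumed by the vLLM worker's `handlers.py`; the concrete values are illustrative only):
+
+```python
+decoded_meta = {
+    "nixl_metadata": "b64:...",   # zlib-compressed, base64-encoded NIXL agent metadata
+    "nixl_descriptor": {          # where the decoded tensor lives in the frontend
+        "addr": 139637976727552,  # illustrative address
+        "size": 1125 * 1999 * 4,  # buffer size in bytes
+        "mem_type": "Dram",
+        "device_id": 0,
+    },
+    # flattened MediaTensorInfo
+    "shape": [1125, 1999, 4],     # height, width, channels
+    "dtype": "UINT8",
+}
+```
+
+The backend allocates a matching `uint8` buffer, performs a NIXL read into it, and rebuilds the image from `shape`.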
+
+## Usage
+
+Media decoding is enabled when registering the model deployment card (MDC).
+
+Set HTTP download options:
+
+```python
+from dynamo.llm import MediaFetcher
+fetcher = MediaFetcher()
+fetcher.user_agent("dynamo")
+fetcher.timeout_ms(15000)
+fetcher.allow_direct_ip(True)
+fetcher.allow_direct_port(False)
+fetcher.allowed_media_domains(["google.com"])
+```
+
+Set media decoding options:
+
+```python
+from dynamo.llm import MediaDecoder
+decoder = MediaDecoder()
+decoder.image_decoder({"max_image_width": 4096, "max_image_height": 4096, "max_alloc": 16*1024*1024})
+```
+
+And register the LLM as usual, adding the media configuration:
+
+```python
+register_llm(
+    ...,
+    media_decoder=decoder,
+    media_fetcher=fetcher,
+)
+```
+
+## TODOs
+
+### Modalities
+
+- [x] Image decoding
+- [ ] Video decoding
+- [ ] Audio decoding
+
+### Performance
+
+- [x] Image SW decoding
+- [ ] Video HW decoding (NVDEC)
+- [ ] JPEG HW decoding (nvJPEG)
+- [ ] Sparse video sampling (seek-forward)
+- [ ] Memory slab pre-allocation/registration
+
+### Memory management
+
+- [ ] Memory spilling to lower storage tiers
+- [ ] Early-free memory on client notifications
+
+### Misc
+
+- [ ] Observability on performance, memory usage and input distributions
+- [ ] Per-request decoding options
diff --git a/lib/llm/src/preprocessor/media/decoders.rs b/lib/llm/src/preprocessor/media/decoders.rs
index aa546915ec..984aab1ea2 100644
--- a/lib/llm/src/preprocessor/media/decoders.rs
+++ b/lib/llm/src/preprocessor/media/decoders.rs
@@ -2,52 +2,14 @@
 // SPDX-License-Identifier: Apache-2.0
 
 use anyhow::Result;
+use serde::{Deserialize, Serialize};
 
 use super::common::EncodedMediaData;
-use ndarray::{ArrayBase, Dimension, OwnedRepr};
 
-mod image;
+use super::rdma::DecodedMediaData;
+
+pub mod image;
 pub use image::{ImageDecoder, ImageMetadata};
 
-#[derive(Debug)]
-pub enum DecodedMediaMetadata {
-    #[allow(dead_code)] // used in followup MR
-    Image(ImageMetadata),
-}
-
-#[derive(Debug, PartialEq, Eq)]
-pub enum DataType {
-    UINT8,
-}
-
-// Decoded media data (image RGB, video frames pixels, ...)
-#[derive(Debug)]
-pub struct DecodedMediaData {
-    #[allow(dead_code)] // used in followup MR
-    pub(crate) data: Vec<u8>,
-    #[allow(dead_code)] // used in followup MR
-    pub(crate) shape: Vec<usize>,
-    #[allow(dead_code)] // used in followup MR
-    pub(crate) dtype: DataType,
-    #[allow(dead_code)] // used in followup MR
-    pub(crate) metadata: Option<DecodedMediaMetadata>,
-}
-
-// convert Array{N} to DecodedMediaData
-// TODO: Array1 for audio
-impl<D: Dimension> From<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
-    fn from(array: ArrayBase<OwnedRepr<u8>, D>) -> Self {
-        let shape = array.shape().to_vec();
-        let (data, _) = array.into_raw_vec_and_offset();
-        Self {
-            data,
-            shape,
-            dtype: DataType::UINT8,
-            metadata: None,
-        }
-    }
-}
-
 #[async_trait::async_trait]
 pub trait Decoder: Clone + Send + 'static {
     fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData>;
@@ -67,3 +29,8 @@
 pub struct MediaDecoder {
     pub image_decoder: ImageDecoder,
     // TODO: video, audio decoders
 }
+
+#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
+pub enum DecodedMediaMetadata {
+    Image(ImageMetadata),
+}
diff --git a/lib/llm/src/preprocessor/media/decoders/image.rs b/lib/llm/src/preprocessor/media/decoders/image.rs
index e6c857d33b..09de4db8d2 100644
--- a/lib/llm/src/preprocessor/media/decoders/image.rs
+++ b/lib/llm/src/preprocessor/media/decoders/image.rs
@@ -6,14 +6,15 @@ use std::io::Cursor;
 
 use anyhow::Result;
 use image::{ColorType, GenericImageView, ImageFormat, ImageReader};
 use ndarray::Array3;
+use serde::{Deserialize, Serialize};
 
 use super::super::common::EncodedMediaData;
-use super::super::decoders::{DecodedMediaData, DecodedMediaMetadata};
-use super::Decoder;
+use super::super::rdma::DecodedMediaData;
+use super::{DecodedMediaMetadata, Decoder};
 
 const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB
 
-#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct ImageDecoder {
     #[serde(default)]
@@ -36,18 +37,15 @@ impl Default for ImageDecoder {
 }
 
 #[allow(clippy::upper_case_acronyms)]
-#[derive(Debug)]
+#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
 pub enum ImageLayout {
     HWC,
 }
 
-#[derive(Debug)]
+#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
 pub struct ImageMetadata {
-    #[allow(dead_code)] // used in followup MR
     pub(crate) format: Option<ImageFormat>,
-    #[allow(dead_code)] // used in followup MR
     pub(crate) color_type: ColorType,
-    #[allow(dead_code)] // used in followup MR
     pub(crate) layout: ImageLayout,
 }
@@ -78,8 +76,8 @@ impl Decoder for ImageDecoder {
         let (width, height) = img.dimensions();
         let shape = (height as usize, width as usize, n_channels as usize);
         let array = Array3::from_shape_vec(shape, data)?;
-        let mut decoded: DecodedMediaData = array.into();
-        decoded.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
+        let mut decoded: DecodedMediaData = array.try_into()?;
+        decoded.tensor_info.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
             format,
             color_type,
             layout: ImageLayout::HWC,
@@ -90,7 +88,7 @@
 
 #[cfg(test)]
 mod tests {
-    use super::super::super::decoders::DataType;
+    use super::super::super::rdma::DataType;
     use super::*;
     use image::{DynamicImage, ImageBuffer};
     use rstest::rstest;
@@ -156,10 +154,10 @@
 
         let decoded = result.unwrap();
         assert_eq!(
-            decoded.shape,
+            decoded.tensor_info.shape,
             vec![height as usize, width as usize, expected_channels as usize]
         );
-        assert_eq!(decoded.dtype, DataType::UINT8);
+        assert_eq!(decoded.tensor_info.dtype, DataType::UINT8);
     }
 
     #[rstest]
@@ -196,9 +194,12 @@
             format
         );
         let decoded = result.unwrap();
-        assert_eq!(decoded.shape, vec![height as usize, width as usize, 3]);
         assert_eq!(
-            decoded.dtype,
+            decoded.tensor_info.shape,
+            vec![height as usize, width as usize, 3]
+        );
+        assert_eq!(
+            decoded.tensor_info.dtype,
             DataType::UINT8,
             "dtype should be uint8 for case: {}",
             test_case
@@ -236,11 +237,15 @@
         );
 
         let decoded = result.unwrap();
-        assert_eq!(decoded.shape.len(), 3, "Should have 3 dimensions");
-        assert_eq!(decoded.shape[0], 1, "Height should be 1");
-        assert_eq!(decoded.shape[1], 1, "Width should be 1");
         assert_eq!(
-            decoded.dtype,
+            decoded.tensor_info.shape.len(),
+            3,
+            "Should have 3 dimensions"
+        );
+        assert_eq!(decoded.tensor_info.shape[0], 1, "Height should be 1");
+        assert_eq!(decoded.tensor_info.shape[1], 1, "Width should be 1");
+        assert_eq!(
+            decoded.tensor_info.dtype,
             DataType::UINT8,
             "dtype should be uint8 for {} channels {:?}",
             input_channels,
diff --git a/lib/llm/src/preprocessor/media/loader.rs b/lib/llm/src/preprocessor/media/loader.rs
index 91fc65d9bc..0d229d7437 100644
--- a/lib/llm/src/preprocessor/media/loader.rs
+++ b/lib/llm/src/preprocessor/media/loader.rs
@@ -8,8 +8,14 @@
 use anyhow::Result;
 use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart;
 
-use super::common::EncodedMediaData;
-use super::decoders::{DecodedMediaData, Decoder, MediaDecoder};
+use super::decoders::MediaDecoder;
+use super::rdma::RdmaMediaDataDescriptor;
+
+#[cfg(feature = "media-nixl")]
+use {
+    super::common::EncodedMediaData, super::decoders::Decoder, super::rdma::get_nixl_agent,
+    dynamo_memory::nixl::NixlAgent,
+};
 
 const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
 const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);
@@ -36,15 +42,19 @@ impl Default for MediaFetcher {
 }
 
 pub struct MediaLoader {
+    #[allow(dead_code)]
     media_decoder: MediaDecoder,
+    #[allow(dead_code)]
     http_client: reqwest::Client,
     media_fetcher: MediaFetcher,
-    // TODO: NIXL agent
+    #[cfg(feature = "media-nixl")]
+    nixl_agent: NixlAgent,
 }
 
 impl MediaLoader {
-    pub fn new(media_decoder: MediaDecoder, media_fetcher: MediaFetcher) -> Result<Self> {
-        let mut http_client_builder =
+    pub fn new(media_decoder: MediaDecoder, media_fetcher: Option<MediaFetcher>) -> Result<Self> {
+        let media_fetcher = media_fetcher.unwrap_or_default();
+        let mut http_client_builder: reqwest::ClientBuilder =
             reqwest::Client::builder().user_agent(&media_fetcher.user_agent);
 
         if let Some(timeout) = media_fetcher.timeout {
@@ -53,10 +63,15 @@
 
         let http_client = http_client_builder.build()?;
 
+        #[cfg(feature = "media-nixl")]
+        let nixl_agent = get_nixl_agent()?;
+
         Ok(Self {
             media_decoder,
             http_client,
             media_fetcher,
+            #[cfg(feature = "media-nixl")]
+            nixl_agent,
         })
     }
 
@@ -90,35 +105,43 @@
         &self,
         oai_content_part: &ChatCompletionRequestUserMessageContentPart,
         // TODO: request-level options
-    ) -> Result<DecodedMediaData> {
-        // fetch the media
-        // TODO: decode and NIXL-register
-        let decoded = match oai_content_part {
-            ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
-                let url = &image_part.image_url.url;
-                self.check_if_url_allowed(url)?;
-                let data = EncodedMediaData::from_url(url, &self.http_client).await?;
-                self.media_decoder.image_decoder.decode_async(data).await?
-            }
-            ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
-                let url = &video_part.video_url.url;
-                self.check_if_url_allowed(url)?;
-                EncodedMediaData::from_url(url, &self.http_client).await?;
-                anyhow::bail!("Video decoding is not supported yet");
-            }
-            ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
-                anyhow::bail!("Audio decoding is not supported yet");
-            }
-            _ => anyhow::bail!("Unsupported media type"),
-        };
+    ) -> Result<RdmaMediaDataDescriptor> {
+        #[cfg(not(feature = "media-nixl"))]
+        anyhow::bail!(
+            "NIXL is not supported, cannot decode and register media data {oai_content_part:?}"
+        );
 
-        Ok(decoded)
+        #[cfg(feature = "media-nixl")]
+        {
+            // fetch the media, decode and NIXL-register
+            let decoded = match oai_content_part {
+                ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
+                    let url = &image_part.image_url.url;
+                    self.check_if_url_allowed(url)?;
+                    let data = EncodedMediaData::from_url(url, &self.http_client).await?;
+                    self.media_decoder.image_decoder.decode_async(data).await?
+                }
+                ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
+                    let url = &video_part.video_url.url;
+                    self.check_if_url_allowed(url)?;
+                    EncodedMediaData::from_url(url, &self.http_client).await?;
+                    anyhow::bail!("Video decoding is not supported yet");
+                }
+                ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
+                    anyhow::bail!("Audio decoding is not supported yet");
+                }
+                _ => anyhow::bail!("Unsupported media type"),
+            };
+
+            let rdma_descriptor = decoded.into_rdma_descriptor(&self.nixl_agent)?;
+            Ok(rdma_descriptor)
+        }
    }
 }
 
-#[cfg(test)]
+#[cfg(all(test, feature = "media-nixl"))]
 mod tests {
-    use super::super::decoders::DataType;
+    use super::super::rdma::DataType;
     use super::*;
     use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};
 
             ..Default::default()
         };
 
-        let loader = MediaLoader::new(media_decoder, fetcher).unwrap();
+        let loader: MediaLoader = MediaLoader::new(media_decoder, Some(fetcher)).unwrap();
 
         let image_url = ImageUrl::from(format!("{}/llm-optimize-deploy-graphic.png", server.url()));
         let content_part = ChatCompletionRequestUserMessageContentPart::ImageUrl(
             ChatCompletionRequestMessageContentPartImage { image_url },
         );
 
         let result = loader.fetch_and_decode_media_part(&content_part).await;
-        assert!(
-            result.is_ok(),
-            "Failed to fetch and decode image: {:?}",
-            result.err()
-        );
-        let data = result.unwrap();
-        assert_eq!(data.dtype, DataType::UINT8);
+        let descriptor = match result {
+            Ok(descriptor) => descriptor,
+            Err(e) if e.to_string().contains("NIXL agent is not available") => {
+                println!("test test_fetch_and_decode ... ignored (NIXL agent not available)");
+                return;
+            }
+            Err(e) => panic!("Failed to fetch and decode image: {}", e),
+        };
+        mock.assert_async().await;
 
+        assert_eq!(descriptor.tensor_info.dtype, DataType::UINT8);
         // Verify image dimensions: 1,999px × 1,125px (width × height)
         // Shape format is [height, width, channels]
-        assert_eq!(data.shape.len(), 3);
-        assert_eq!(data.shape[0], 1125, "Height should be 1125");
-        assert_eq!(data.shape[1], 1999, "Width should be 1999");
-        assert_eq!(data.shape[2], 4, "RGBA channels should be 4");
+        assert_eq!(descriptor.tensor_info.shape.len(), 3);
+        assert_eq!(
+            descriptor.tensor_info.shape[0], 1125,
+            "Height should be 1125"
+        );
+        assert_eq!(
+            descriptor.tensor_info.shape[1], 1999,
+            "Width should be 1999"
+        );
+        assert_eq!(
+            descriptor.tensor_info.shape[2], 4,
+            "RGBA channels should be 4"
+        );
 
-        mock.assert_async().await;
+        assert!(
+            descriptor.source_storage.is_some(),
+            "Source storage should be present"
+        );
+        assert!(
+            descriptor.source_storage.unwrap().is_registered(),
+            "Source storage should be registered with NIXL"
+        );
     }
+}
+
+#[cfg(test)]
+mod tests_non_nixl {
+    use super::*;
 
     #[test]
     fn test_direct_ip_blocked() {
         let fetcher = MediaFetcher {
             allow_direct_ip: false,
             ..Default::default()
         };
-        let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap();
+        let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
 
         let url = url::Url::parse("http://192.168.1.1/image.jpg").unwrap();
         let result = loader.check_if_url_allowed(&url);
@@ -196,7 +243,7 @@
             allow_direct_port: false,
             ..Default::default()
         };
-        let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap();
+        let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
 
         let url = url::Url::parse("http://example.com:8080/image.jpg").unwrap();
         let result = loader.check_if_url_allowed(&url);
@@ -220,7 +267,7 @@
             allowed_media_domains: Some(allowed_domains),
             ..Default::default()
         };
-        let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap();
+        let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
 
         // Allowed domain should pass
         let url = url::Url::parse("https://trusted.com/image.jpg").unwrap();
diff --git a/lib/llm/src/preprocessor/media/rdma.rs b/lib/llm/src/preprocessor/media/rdma.rs
new file mode 100644
index 0000000000..38363e1b71
--- /dev/null
+++ b/lib/llm/src/preprocessor/media/rdma.rs
@@ -0,0 +1,144 @@
+// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use anyhow::Result;
+use ndarray::{ArrayBase, Dimension, OwnedRepr};
+use serde::{Deserialize, Serialize};
+
+#[cfg(feature = "media-nixl")]
+use {
+    base64::{Engine as _, engine::general_purpose},
+    dynamo_memory::SystemStorage,
+    dynamo_memory::nixl::{self, NixlAgent, NixlDescriptor, RegisteredView},
+    std::sync::Arc,
+};
+
+use super::decoders::DecodedMediaMetadata;
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
+pub enum DataType {
+    UINT8,
+}
+
+// Common tensor metadata shared between decoded and RDMA descriptors
+#[derive(Serialize, Deserialize, Clone, Debug)]
+pub struct MediaTensorInfo {
+    pub(crate) shape: Vec<usize>,
+    pub(crate) dtype: DataType,
+    pub(crate) metadata: Option<DecodedMediaMetadata>,
+}
+
+// Decoded media data (image RGB, video frames pixels, ...)
+#[derive(Debug)]
+pub struct DecodedMediaData {
+    #[cfg(feature = "media-nixl")]
+    pub(crate) data: SystemStorage,
+    pub(crate) tensor_info: MediaTensorInfo,
+}
+
+// Decoded media data NIXL descriptor (sent to the next step in the pipeline / NATS)
+
+#[derive(Serialize, Deserialize, Clone, Debug)]
+pub struct RdmaMediaDataDescriptor {
+    // b64 agent metadata
+    #[cfg(feature = "media-nixl")]
+    pub(crate) nixl_metadata: String,
+    // tensor descriptor
+    #[cfg(feature = "media-nixl")]
+    pub(crate) nixl_descriptor: NixlDescriptor,
+
+    #[serde(flatten)]
+    pub(crate) tensor_info: MediaTensorInfo,
+
+    // reference to the actual data, kept alive while the rdma descriptor is alive
+    #[serde(skip, default)]
+    #[allow(dead_code)]
+    #[cfg(feature = "media-nixl")]
+    pub(crate) source_storage: Option<Arc<RegisteredView<SystemStorage>>>,
+}
+
+impl DecodedMediaData {
+    #[cfg(feature = "media-nixl")]
+    pub fn into_rdma_descriptor(self, nixl_agent: &NixlAgent) -> Result<RdmaMediaDataDescriptor> {
+        let source_storage = self.data;
+        let registered = nixl::register_with_nixl(source_storage, nixl_agent, None)
+            .map_err(|_| anyhow::anyhow!("Failed to register storage with NIXL"))?;
+
+        let nixl_descriptor = registered.descriptor();
+        let nixl_metadata = get_nixl_metadata(nixl_agent, registered.storage())?;
+
+        Ok(RdmaMediaDataDescriptor {
+            nixl_metadata,
+            nixl_descriptor,
+            tensor_info: self.tensor_info,
+            // Keep registered storage alive
+            source_storage: Some(Arc::new(registered)),
+        })
+    }
+}
+
+// convert Array{N} to DecodedMediaData
+// TODO: Array1 for audio
+impl<D: Dimension> TryFrom<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
+    type Error = anyhow::Error;
+
+    fn try_from(array: ArrayBase<OwnedRepr<u8>, D>) -> Result<Self> {
+        let shape = array.shape().to_vec();
+
+        #[cfg(feature = "media-nixl")]
+        let (data_vec, _) = array.into_raw_vec_and_offset();
+        #[cfg(feature = "media-nixl")]
+        let mut storage = SystemStorage::new(data_vec.len())?;
+        #[cfg(feature = "media-nixl")]
+        unsafe {
+            std::ptr::copy_nonoverlapping(data_vec.as_ptr(), storage.as_mut_ptr(), data_vec.len());
+        }
+
+        Ok(Self {
+            #[cfg(feature = "media-nixl")]
+            data: storage,
+            tensor_info: MediaTensorInfo {
+                shape,
+                dtype: DataType::UINT8,
+                metadata: None,
+            },
+        })
+    }
+}
+
+// Get NIXL metadata for a descriptor
+// Avoids cross-request leak possibility and reduces metadata size
+// TODO: pre-allocate a fixed NIXL-registered RAM pool so metadata can be cached on the target?
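+// NOTE: with the full-metadata workaround below, `_storage` is currently unused;
+// the agent's entire local metadata is shipped with every request rather than a
+// partial descriptor covering only this buffer (see the nixl PR linked below).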
+#[cfg(feature = "media-nixl")]
+pub fn get_nixl_metadata(agent: &NixlAgent, _storage: &SystemStorage) -> Result<String> {
+    use flate2::Compression;
+    use flate2::write::ZlibEncoder;
+    use std::io::Write;
+
+    // WAR: Until https://github.com/ai-dynamo/nixl/pull/970 is merged, can't use get_local_partial_md
+    let nixl_md = agent.raw_agent().get_local_md()?;
+    // let mut reg_desc_list = RegDescList::new(MemType::Dram)?;
+    // reg_desc_list.add_storage_desc(storage)?;
+    // let nixl_partial_md = agent.raw_agent().get_local_partial_md(&reg_desc_list, None)?;
+
+    // Compress metadata before base64 encoding (matches Python nixl_connect behavior)
+    // Backend expects: "b64:" + base64(zlib(metadata))
+    // Note: Python nixl_connect automatically decompresses when seeing the "b64:" prefix
+    let mut zlib_encoder = ZlibEncoder::new(Vec::new(), Compression::new(6));
+    zlib_encoder.write_all(&nixl_md)?;
+    let compressed = zlib_encoder.finish()?;
+
+    let b64_encoded = general_purpose::STANDARD.encode(&compressed);
+    Ok(format!("b64:{}", b64_encoded))
+}
+
+#[cfg(feature = "media-nixl")]
+pub fn get_nixl_agent() -> Result<NixlAgent> {
+    let name = format!("media-loader-{}", uuid::Uuid::new_v4());
+    let nixl_agent = NixlAgent::with_backends(&name, &["UCX"])?;
+    Ok(nixl_agent)
+}
diff --git a/lib/llm/src/protocols/common/preprocessor.rs b/lib/llm/src/protocols/common/preprocessor.rs
index f4fc92f7f4..d8692c608e 100644
--- a/lib/llm/src/protocols/common/preprocessor.rs
+++ b/lib/llm/src/protocols/common/preprocessor.rs
@@ -6,6 +6,8 @@ use serde::{Deserialize, Serialize};
 
 use super::{OutputOptions, SamplingOptions, StopConditions};
 use crate::kv_router::RouterConfigOverride;
+#[cfg(feature = "media-nixl")]
+use crate::preprocessor::media::RdmaMediaDataDescriptor;
 use crate::protocols::TokenIdType;
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
@@ -20,7 +22,8 @@ pub struct PrefillResult {
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub enum MultimodalData {
     Url(url::Url),
-    // TODO: Decoded(DecodedMediaData),
+    #[cfg(feature = "media-nixl")]
+    Decoded(RdmaMediaDataDescriptor),
 }
 
 // multimodal map containing {mm_part_type: [data...]}
@@ -40,6 +43,7 @@ pub struct PreprocessedRequest {
     #[builder(default)]
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub multi_modal_data: Option<MultimodalDataMap>,
+    /// StopConditions are conditions that the inference engine will use to stop generation.
     pub stop_conditions: StopConditions,