diff --git a/Cargo.lock b/Cargo.lock
index 86c26e48f4..e1dc436ab6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2241,6 +2241,7 @@ dependencies = [
  "derive_builder",
  "dialoguer",
  "dynamo-async-openai",
+ "dynamo-memory",
  "dynamo-parsers",
  "dynamo-runtime",
  "either",
@@ -3992,6 +3993,7 @@ dependencies = [
  "ravif",
  "rayon",
  "rgb",
+ "serde",
  "tiff",
  "zune-core",
  "zune-jpeg",
diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock
index 5eeaa775fc..0c6059530d 100644
--- a/lib/bindings/python/Cargo.lock
+++ b/lib/bindings/python/Cargo.lock
@@ -2966,6 +2966,7 @@ dependencies = [
  "ravif",
  "rayon",
  "rgb",
+ "serde",
  "tiff",
  "zune-core",
  "zune-jpeg",
@@ -3912,6 +3913,16 @@ dependencies = [
  "pxfm",
 ]
 
+[[package]]
+name = "moxcms"
+version = "0.7.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0fbdd3d7436f8b5e892b8b7ea114271ff0fa00bc5acae845d53b07d498616ef6"
+dependencies = [
+ "num-traits",
+ "pxfm",
+]
+
 [[package]]
 name = "multimap"
 version = "0.10.1"
diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml
index 87fece25cd..17901c7675 100644
--- a/lib/llm/Cargo.toml
+++ b/lib/llm/Cargo.toml
@@ -24,6 +24,7 @@ testing-etcd = []
 block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:nix", "dep:aligned-vec"]
 cuda = ["dep:cudarc"]
 integration = ["dynamo-runtime/integration"]
+media-nixl = ["dep:nixl-sys", "dep:dynamo-memory"]
 
 [[bench]]
 name = "tokenizer"
@@ -33,9 +34,11 @@ harness = false
 name = "transfer_context_v2"
 harness = false
 required-features = ["block-manager", "testing-cuda"]
+
 [dependencies]
 # repo
 dynamo-runtime = { workspace = true }
+dynamo-memory = { path = "../memory", version = "0.7.0", optional = true }
 
 # workspace
 aho-corasick = "1.1"
@@ -145,7 +148,7 @@ json-five = { version = "0.3" }
 # media loading in the preprocessor
 reqwest = { workspace = true }
 base64 = { version = "0.22" }
-image = { version = "0.25" }
+image = { version = "0.25", features = ["serde"] }
 tokio-rayon = {version = "2" }
 ndarray = { version = "0.16" }
diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs
index e7374d8a3a..c3eead7691 100644
--- a/lib/llm/src/preprocessor.rs
+++ b/lib/llm/src/preprocessor.rs
@@ -27,7 +27,8 @@ use std::{collections::HashMap, pin::Pin, sync::Arc};
 use tracing;
 
 use crate::model_card::{ModelDeploymentCard, ModelInfo};
-use crate::preprocessor::media::MediaLoader;
+#[cfg(feature = "media-nixl")]
+use crate::preprocessor::media::{MediaDecoder, MediaFetcher, MediaLoader};
 use crate::preprocessor::prompt::OAIChatLikeRequest;
 use crate::protocols::common::preprocessor::{
     MultimodalData, MultimodalDataMap, PreprocessedRequestBuilder,
@@ -114,6 +115,7 @@ pub struct OpenAIPreprocessor {
     /// Per-model runtime configuration propagated to response generator (e.g., reasoning/tool parser)
     runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig,
     tool_call_parser: Option<String>,
+    #[cfg(feature = "media-nixl")]
     media_loader: Option<MediaLoader>,
 }
 
@@ -143,7 +145,13 @@ impl OpenAIPreprocessor {
         //
         // Initialize runtime config from the ModelDeploymentCard
         let runtime_config = mdc.runtime_config.clone();
-        let media_loader = None; // TODO: enable with decoder config from MDC
+
+        #[cfg(feature = "media-nixl")]
+        let media_loader = match mdc.media_decoder {
+            Some(media_decoder) => Some(MediaLoader::new(media_decoder, mdc.media_fetcher)?),
+            None => None,
+        };
+
         Ok(Arc::new(Self {
             formatter,
             tokenizer,
@@ -151,6 +159,7 @@ impl OpenAIPreprocessor {
             mdcsum,
             runtime_config,
             tool_call_parser,
+            #[cfg(feature = "media-nixl")]
             media_loader,
         }))
     }
@@ -279,7 +288,9 @@ impl OpenAIPreprocessor {
         let messages = request.messages();
         let message_count = messages.len().unwrap_or(0);
         let mut media_map: MultimodalDataMap = HashMap::new();
-        let mut fetch_tasks = Vec::new();
+        #[cfg(feature = "media-nixl")]
+        let mut fetch_tasks: Vec<(String, ChatCompletionRequestUserMessageContentPart)> =
+            Vec::new();
 
         for idx in 0..message_count {
             let msg = messages
@@ -312,29 +323,39 @@ impl OpenAIPreprocessor {
                     _ => continue,
                 };
 
+                #[cfg(feature = "media-nixl")]
                 if self.media_loader.is_some() {
                     fetch_tasks.push((type_str, content_part.clone()));
-                } else {
-                    // No loader, just pass the URL through
-                    media_map
-                        .entry(type_str)
-                        .or_default()
-                        .push(MultimodalData::Url(url));
+                    continue;
                 }
+
+                // Fallback: just pass the URL through
+                media_map
+                    .entry(type_str)
+                    .or_default()
+                    .push(MultimodalData::Url(url));
             }
         }
 
         // Execute all fetch tasks
+        #[cfg(feature = "media-nixl")]
         if !fetch_tasks.is_empty() {
             let loader = self.media_loader.as_ref().unwrap();
-            let _results = futures::future::join_all(
+            let results = futures::future::join_all(
                 fetch_tasks
                     .iter()
                     .map(|(_, content_part)| loader.fetch_and_decode_media_part(content_part)),
             )
             .await;
-            // TODO: decode and pass NIXL descriptors to the media map
+
+            for ((type_str, _), result) in fetch_tasks.into_iter().zip(results.into_iter()) {
+                // If one item fails, fail the whole request; the other items are cleaned up by Drop.
+                let rdma_descriptor = result?;
+                media_map
+                    .entry(type_str)
+                    .or_default()
+                    .push(MultimodalData::Decoded(rdma_descriptor));
+            }
         }
 
         if !media_map.is_empty() {
diff --git a/lib/llm/src/preprocessor/media.rs b/lib/llm/src/preprocessor/media.rs
index 0c0e3e6b12..65566c9937 100644
--- a/lib/llm/src/preprocessor/media.rs
+++ b/lib/llm/src/preprocessor/media.rs
@@ -4,7 +4,12 @@
 mod common;
 mod decoders;
 mod loader;
+mod rdma;
 
 pub use common::EncodedMediaData;
 pub use decoders::{Decoder, ImageDecoder, MediaDecoder};
 pub use loader::{MediaFetcher, MediaLoader};
+
+pub use rdma::{DecodedMediaData, RdmaMediaDataDescriptor};
+#[cfg(feature = "media-nixl")]
+pub use rdma::{get_nixl_agent, get_nixl_metadata};
diff --git a/lib/llm/src/preprocessor/media/README.md b/lib/llm/src/preprocessor/media/README.md
new file mode 100644
index 0000000000..fede33bc9f
--- /dev/null
+++ b/lib/llm/src/preprocessor/media/README.md
@@ -0,0 +1,63 @@
+# Media decoding in the frontend
+
+This component performs media download, base64 decoding, media decoding, and NIXL registration. Today it is used in the OpenAI preprocessor to transform multimodal inputs (image_url, video_url, audio_url) into fully decoded data (pixel values, ...) that the backends can access via NIXL.
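+
+For example, once a model is registered with media decoding enabled, a standard OpenAI-compatible multimodal request exercises this path end to end. The snippet below is illustrative only; the endpoint, port, and model name are placeholders:
+
+```python
+import requests
+
+# Hypothetical endpoint and model name; adjust to your deployment
+response = requests.post(
+    "http://localhost:8000/v1/chat/completions",
+    json={
+        "model": "my-vlm",
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Describe this image."},
+                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
+                ],
+            }
+        ],
+    },
+)
+print(response.json())
+```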
+
+## Usage
+
+Media decoding is enabled when registering the model deployment card (MDC):
+
+Set HTTP download options:
+
+```python
+from dynamo.llm import MediaFetcher
+fetcher = MediaFetcher()
+fetcher.user_agent("dynamo")
+fetcher.timeout_ms(15000)
+fetcher.allow_direct_ip(True)
+fetcher.allow_direct_port(False)
+fetcher.allowed_media_domains(["google.com"])
+```
+
+Set media decoding options:
+
+```python
+from dynamo.llm import MediaDecoder
+decoder = MediaDecoder()
+decoder.image_decoder({"max_image_width": 4096, "max_image_height": 4096, "max_alloc": 16*1024*1024})
+```
+
+Then register the LLM as usual, adding the media configuration:
+
+```python
+register_llm(
+    ...,
+    media_decoder=decoder,
+    media_fetcher=fetcher,
+)
+```
+
+## TODOs
+
+### Modalities
+
+- [x] Image decoding
+- [ ] Video decoding
+- [ ] Audio decoding
+
+### Performance
+
+- [x] Image SW decoding
+- [ ] Video HW decoding (NVDEC)
+- [ ] JPEG HW decoding (nvJPEG)
+- [ ] Sparse video sampling (seek-forward)
+- [ ] Memory slab pre-allocation/registration
+
+### Memory management
+
+- [ ] Memory spilling to lower storage tiers
+- [ ] Early memory freeing on client notifications
+
+### Misc
+
+- [ ] Observability on performance, memory usage and input distributions
+- [ ] Per-request decoding options
diff --git a/lib/llm/src/preprocessor/media/decoders.rs b/lib/llm/src/preprocessor/media/decoders.rs
index aa546915ec..984aab1ea2 100644
--- a/lib/llm/src/preprocessor/media/decoders.rs
+++ b/lib/llm/src/preprocessor/media/decoders.rs
@@ -2,52 +2,14 @@
 // SPDX-License-Identifier: Apache-2.0
 
 use anyhow::Result;
+use serde::{Deserialize, Serialize};
 
 use super::common::EncodedMediaData;
-use ndarray::{ArrayBase, Dimension, OwnedRepr};
-
-mod image;
+use super::rdma::DecodedMediaData;
+
+pub mod image;
 pub use image::{ImageDecoder, ImageMetadata};
 
-#[derive(Debug)]
-pub enum DecodedMediaMetadata {
-    #[allow(dead_code)] // used in followup MR
-    Image(ImageMetadata),
-}
-
-#[derive(Debug, PartialEq, Eq)]
-pub enum DataType {
-    UINT8,
-}
-
-// Decoded media data (image RGB, video frames pixels, ...)
-#[derive(Debug)]
-pub struct DecodedMediaData {
-    #[allow(dead_code)] // used in followup MR
-    pub(crate) data: Vec<u8>,
-    #[allow(dead_code)] // used in followup MR
-    pub(crate) shape: Vec<usize>,
-    #[allow(dead_code)] // used in followup MR
-    pub(crate) dtype: DataType,
-    #[allow(dead_code)] // used in followup MR
-    pub(crate) metadata: Option<DecodedMediaMetadata>,
-}
-
-// convert Array{N} to DecodedMediaData
-// TODO: Array1 for audio
-impl<D: Dimension> From<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
-    fn from(array: ArrayBase<OwnedRepr<u8>, D>) -> Self {
-        let shape = array.shape().to_vec();
-        let (data, _) = array.into_raw_vec_and_offset();
-        Self {
-            data,
-            shape,
-            dtype: DataType::UINT8,
-            metadata: None,
-        }
-    }
-}
-
 #[async_trait::async_trait]
 pub trait Decoder: Clone + Send + 'static {
     fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData>;
@@ -67,3 +29,8 @@ pub struct MediaDecoder {
     pub image_decoder: ImageDecoder,
     // TODO: video, audio decoders
 }
+
+#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
+pub enum DecodedMediaMetadata {
+    Image(ImageMetadata),
+}
diff --git a/lib/llm/src/preprocessor/media/decoders/image.rs b/lib/llm/src/preprocessor/media/decoders/image.rs
index e6c857d33b..09de4db8d2 100644
--- a/lib/llm/src/preprocessor/media/decoders/image.rs
+++ b/lib/llm/src/preprocessor/media/decoders/image.rs
@@ -6,14 +6,15 @@ use std::io::Cursor;
 use anyhow::Result;
 use image::{ColorType, GenericImageView, ImageFormat, ImageReader};
 use ndarray::Array3;
+use serde::{Deserialize, Serialize};
 
 use super::super::common::EncodedMediaData;
-use super::super::decoders::{DecodedMediaData, DecodedMediaMetadata};
-use super::Decoder;
+use super::super::rdma::DecodedMediaData;
+use super::{DecodedMediaMetadata, Decoder};
 
 const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB
 
-#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct ImageDecoder {
     #[serde(default)]
@@ -36,18 +37,15 @@ impl Default for ImageDecoder {
 }
 
 #[allow(clippy::upper_case_acronyms)]
-#[derive(Debug)]
+#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
 pub enum ImageLayout {
     HWC,
 }
 
-#[derive(Debug)]
+#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
 pub struct ImageMetadata {
-    #[allow(dead_code)] // used in followup MR
     pub(crate) format: Option<ImageFormat>,
-    #[allow(dead_code)] // used in followup MR
     pub(crate) color_type: ColorType,
-    #[allow(dead_code)] // used in followup MR
     pub(crate) layout: ImageLayout,
 }
 
@@ -78,8 +76,8 @@ impl Decoder for ImageDecoder {
         let (width, height) = img.dimensions();
         let shape = (height as usize, width as usize, n_channels as usize);
         let array = Array3::from_shape_vec(shape, data)?;
-        let mut decoded: DecodedMediaData = array.into();
-        decoded.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
+        let mut decoded: DecodedMediaData = array.try_into()?;
+        decoded.tensor_info.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
             format,
             color_type,
             layout: ImageLayout::HWC,
@@ -90,7 +88,7 @@
 
 #[cfg(test)]
 mod tests {
-    use super::super::super::decoders::DataType;
+    use super::super::super::rdma::DataType;
     use super::*;
     use image::{DynamicImage, ImageBuffer};
     use rstest::rstest;
@@ -156,10 +154,10 @@
         let decoded = result.unwrap();
 
         assert_eq!(
-            decoded.shape,
+            decoded.tensor_info.shape,
             vec![height as usize, width as usize, expected_channels as usize]
         );
-        assert_eq!(decoded.dtype, DataType::UINT8);
+        assert_eq!(decoded.tensor_info.dtype, DataType::UINT8);
     }
 
     #[rstest]
@@ -196,9 +194,12 @@
             format
         );
         let decoded = result.unwrap();
-        assert_eq!(decoded.shape, vec![height as usize, width as usize, 3]);
         assert_eq!(
-            decoded.dtype,
+            decoded.tensor_info.shape,
+            vec![height as usize, width as usize, 3]
+        );
+        assert_eq!(
+            decoded.tensor_info.dtype,
             DataType::UINT8,
             "dtype should be uint8 for case: {}",
             test_case
@@ -236,11 +237,15 @@
         );
 
         let decoded = result.unwrap();
-        assert_eq!(decoded.shape.len(), 3, "Should have 3 dimensions");
-        assert_eq!(decoded.shape[0], 1, "Height should be 1");
-        assert_eq!(decoded.shape[1], 1, "Width should be 1");
         assert_eq!(
-            decoded.dtype,
+            decoded.tensor_info.shape.len(),
+            3,
+            "Should have 3 dimensions"
+        );
+        assert_eq!(decoded.tensor_info.shape[0], 1, "Height should be 1");
+        assert_eq!(decoded.tensor_info.shape[1], 1, "Width should be 1");
+        assert_eq!(
+            decoded.tensor_info.dtype,
             DataType::UINT8,
             "dtype should be uint8 for {} channels {:?}",
             input_channels,
diff --git a/lib/llm/src/preprocessor/media/loader.rs b/lib/llm/src/preprocessor/media/loader.rs
index 91fc65d9bc..0d229d7437 100644
--- a/lib/llm/src/preprocessor/media/loader.rs
+++ b/lib/llm/src/preprocessor/media/loader.rs
@@ -8,8 +8,14 @@
 use anyhow::Result;
 use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart;
 
-use super::common::EncodedMediaData;
-use super::decoders::{DecodedMediaData, Decoder, MediaDecoder};
+use super::decoders::MediaDecoder;
+use super::rdma::RdmaMediaDataDescriptor;
+
+#[cfg(feature = "media-nixl")]
+use {
+    super::common::EncodedMediaData, super::decoders::Decoder, super::rdma::get_nixl_agent,
+    dynamo_memory::nixl::NixlAgent,
+};
 
 const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
 const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);
@@ -36,15 +42,19 @@ impl Default for MediaFetcher {
 }
 
 pub struct MediaLoader {
+    #[allow(dead_code)]
     media_decoder: MediaDecoder,
+    #[allow(dead_code)]
     http_client: reqwest::Client,
     media_fetcher: MediaFetcher,
-    // TODO: NIXL agent
+    #[cfg(feature = "media-nixl")]
+    nixl_agent: NixlAgent,
 }
 
 impl MediaLoader {
-    pub fn new(media_decoder: MediaDecoder, media_fetcher: MediaFetcher) -> Result<Self> {
-        let mut http_client_builder =
+    pub fn new(media_decoder: MediaDecoder, media_fetcher: Option<MediaFetcher>) -> Result<Self> {
+        let media_fetcher = media_fetcher.unwrap_or_default();
+        let mut http_client_builder: reqwest::ClientBuilder =
             reqwest::Client::builder().user_agent(&media_fetcher.user_agent);
 
         if let Some(timeout) = media_fetcher.timeout {
@@ -53,10 +63,15 @@
 
         let http_client = http_client_builder.build()?;
 
+        #[cfg(feature = "media-nixl")]
+        let nixl_agent = get_nixl_agent()?;
+
         Ok(Self {
             media_decoder,
             http_client,
             media_fetcher,
+            #[cfg(feature = "media-nixl")]
+            nixl_agent,
         })
     }
@@ -90,35 +105,43 @@
         &self,
         oai_content_part: &ChatCompletionRequestUserMessageContentPart,
         // TODO: request-level options
-    ) -> Result<DecodedMediaData> {
-        // fetch the media
-        // TODO: decode and NIXL-register
-        let decoded = match oai_content_part {
-            ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
-                let url = &image_part.image_url.url;
-                self.check_if_url_allowed(url)?;
-                let data = EncodedMediaData::from_url(url, &self.http_client).await?;
-                self.media_decoder.image_decoder.decode_async(data).await?
-            }
-            ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
-                let url = &video_part.video_url.url;
-                self.check_if_url_allowed(url)?;
-                EncodedMediaData::from_url(url, &self.http_client).await?;
-                anyhow::bail!("Video decoding is not supported yet");
-            }
-            ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
-                anyhow::bail!("Audio decoding is not supported yet");
-            }
-            _ => anyhow::bail!("Unsupported media type"),
-        };
+    ) -> Result<RdmaMediaDataDescriptor> {
+        #[cfg(not(feature = "media-nixl"))]
+        anyhow::bail!(
+            "NIXL is not supported, cannot decode and register media data {oai_content_part:?}"
+        );
 
-        Ok(decoded)
+        #[cfg(feature = "media-nixl")]
+        {
+            // fetch the media, decode and NIXL-register
+            let decoded = match oai_content_part {
+                ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
+                    let url = &image_part.image_url.url;
+                    self.check_if_url_allowed(url)?;
+                    let data = EncodedMediaData::from_url(url, &self.http_client).await?;
+                    self.media_decoder.image_decoder.decode_async(data).await?
+                }
+                ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
+                    let url = &video_part.video_url.url;
+                    self.check_if_url_allowed(url)?;
+                    EncodedMediaData::from_url(url, &self.http_client).await?;
+                    anyhow::bail!("Video decoding is not supported yet");
+                }
+                ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
+                    anyhow::bail!("Audio decoding is not supported yet");
+                }
+                _ => anyhow::bail!("Unsupported media type"),
+            };
+
+            let rdma_descriptor = decoded.into_rdma_descriptor(&self.nixl_agent)?;
+            Ok(rdma_descriptor)
+        }
     }
 }
 
-#[cfg(test)]
+#[cfg(all(test, feature = "media-nixl"))]
 mod tests {
-    use super::super::decoders::DataType;
+    use super::super::rdma::DataType;
     use super::*;
     use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};
@@ -143,7 +166,7 @@
             ..Default::default()
         };
 
-        let loader = MediaLoader::new(media_decoder, fetcher).unwrap();
+        let loader: MediaLoader = MediaLoader::new(media_decoder, Some(fetcher)).unwrap();
 
         let image_url = ImageUrl::from(format!("{}/llm-optimize-deploy-graphic.png", server.url()));
         let content_part = ChatCompletionRequestUserMessageContentPart::ImageUrl(
@@ -151,24 +174,48 @@
         );
 
         let result = loader.fetch_and_decode_media_part(&content_part).await;
-        assert!(
-            result.is_ok(),
-            "Failed to fetch and decode image: {:?}",
-            result.err()
-        );
-        let data = result.unwrap();
-        assert_eq!(data.dtype, DataType::UINT8);
+        let descriptor = match result {
+            Ok(descriptor) => descriptor,
+            Err(e) if e.to_string().contains("NIXL agent is not available") => {
+                println!("test test_fetch_and_decode ... ignored (NIXL agent not available)");
+                return;
+            }
+            Err(e) => panic!("Failed to fetch and decode image: {}", e),
+        };
+        mock.assert_async().await;
+        assert_eq!(descriptor.tensor_info.dtype, DataType::UINT8);
 
         // Verify image dimensions: 1,999px × 1,125px (width × height)
         // Shape format is [height, width, channels]
-        assert_eq!(data.shape.len(), 3);
-        assert_eq!(data.shape[0], 1125, "Height should be 1125");
-        assert_eq!(data.shape[1], 1999, "Width should be 1999");
-        assert_eq!(data.shape[2], 4, "RGBA channels should be 4");
+        assert_eq!(descriptor.tensor_info.shape.len(), 3);
+        assert_eq!(
+            descriptor.tensor_info.shape[0], 1125,
+            "Height should be 1125"
+        );
+        assert_eq!(
+            descriptor.tensor_info.shape[1], 1999,
+            "Width should be 1999"
+        );
+        assert_eq!(
+            descriptor.tensor_info.shape[2], 4,
+            "RGBA channels should be 4"
+        );
 
-        mock.assert_async().await;
+        assert!(
+            descriptor.source_storage.is_some(),
+            "Source storage should be present"
+        );
+        assert!(
+            descriptor.source_storage.unwrap().is_registered(),
+            "Source storage should be registered with NIXL"
+        );
     }
+}
+
+#[cfg(test)]
+mod tests_non_nixl {
+    use super::*;
 
     #[test]
     fn test_direct_ip_blocked() {
@@ -176,7 +223,7 @@
             allow_direct_ip: false,
             ..Default::default()
         };
-        let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap();
+        let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
 
         let url = url::Url::parse("http://192.168.1.1/image.jpg").unwrap();
         let result = loader.check_if_url_allowed(&url);
@@ -196,7 +243,7 @@
             allow_direct_port: false,
             ..Default::default()
         };
-        let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap();
+        let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
 
         let url = url::Url::parse("http://example.com:8080/image.jpg").unwrap();
         let result = loader.check_if_url_allowed(&url);
@@ -220,7 +267,7 @@
             allowed_media_domains: Some(allowed_domains),
             ..Default::default()
         };
-        let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap();
+        let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
 
         // Allowed domain should pass
         let url = url::Url::parse("https://trusted.com/image.jpg").unwrap();
diff --git a/lib/llm/src/preprocessor/media/rdma.rs b/lib/llm/src/preprocessor/media/rdma.rs
new file mode 100644
index 0000000000..fc36c09d2f
--- /dev/null
+++ b/lib/llm/src/preprocessor/media/rdma.rs
@@ -0,0 +1,130 @@
+// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use anyhow::Result;
+use ndarray::{ArrayBase, Dimension, OwnedRepr};
+use serde::{Deserialize, Serialize};
+
+#[cfg(feature = "media-nixl")]
+use {
+    base64::{Engine as _, engine::general_purpose},
+    dynamo_memory::SystemStorage,
+    dynamo_memory::nixl::{self, NixlAgent, NixlDescriptor, RegisteredView},
+    std::sync::Arc,
+};
+
+use super::decoders::DecodedMediaMetadata;
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
+pub enum DataType {
+    UINT8,
+}
+
+// Common tensor metadata shared between decoded and RDMA descriptors
+#[derive(Serialize, Deserialize, Clone, Debug)]
+pub struct MediaTensorInfo {
+    pub(crate) shape: Vec<usize>,
+    pub(crate) dtype: DataType,
+    pub(crate) metadata: Option<DecodedMediaMetadata>,
+}
+
+// Decoded media data (image RGB, video frames pixels, ...)
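+// Holds the decoded bytes (in NIXL-registerable system memory when the
+// "media-nixl" feature is enabled) together with the tensor shape/dtype info.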
+#[derive(Debug)]
+pub struct DecodedMediaData {
+    #[cfg(feature = "media-nixl")]
+    pub(crate) data: SystemStorage,
+    pub(crate) tensor_info: MediaTensorInfo,
+}
+
+// Decoded media data NIXL descriptor (sent to the next step in the pipeline / NATS)
+
+#[derive(Serialize, Deserialize, Clone, Debug)]
+pub struct RdmaMediaDataDescriptor {
+    // base64-encoded agent metadata
+    #[cfg(feature = "media-nixl")]
+    pub(crate) nixl_metadata: String,
+    // tensor descriptor
+    #[cfg(feature = "media-nixl")]
+    pub(crate) nixl_descriptor: NixlDescriptor,
+
+    #[serde(flatten)]
+    pub(crate) tensor_info: MediaTensorInfo,
+
+    // reference to the actual data, kept alive while the RDMA descriptor is alive
+    #[serde(skip, default)]
+    #[allow(dead_code)]
+    #[cfg(feature = "media-nixl")]
+    pub(crate) source_storage: Option<Arc<RegisteredView<SystemStorage>>>,
+}
+
+impl DecodedMediaData {
+    #[cfg(feature = "media-nixl")]
+    pub fn into_rdma_descriptor(self, nixl_agent: &NixlAgent) -> Result<RdmaMediaDataDescriptor> {
+        let source_storage = self.data;
+        let registered = nixl::register_with_nixl(source_storage, nixl_agent, None)
+            .map_err(|_| anyhow::anyhow!("Failed to register storage with NIXL"))?;
+
+        let nixl_descriptor = registered.descriptor();
+        let nixl_metadata = get_nixl_metadata(nixl_agent, registered.storage())?;
+
+        Ok(RdmaMediaDataDescriptor {
+            nixl_metadata,
+            nixl_descriptor,
+            tensor_info: self.tensor_info,
+            // Keep registered storage alive
+            source_storage: Some(Arc::new(registered)),
+        })
+    }
+}
+
+// convert Array{N} to DecodedMediaData
+// TODO: Array1 for audio
+impl<D: Dimension> TryFrom<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
+    type Error = anyhow::Error;
+
+    fn try_from(array: ArrayBase<OwnedRepr<u8>, D>) -> Result<Self> {
+        let shape = array.shape().to_vec();
+
+        #[cfg(feature = "media-nixl")]
+        let (data_vec, _) = array.into_raw_vec_and_offset();
+        #[cfg(feature = "media-nixl")]
+        let mut storage = SystemStorage::new(data_vec.len())?;
+        #[cfg(feature = "media-nixl")]
+        unsafe {
+            std::ptr::copy_nonoverlapping(data_vec.as_ptr(), storage.as_mut_ptr(), data_vec.len());
+        }
+
+        Ok(Self {
+            #[cfg(feature = "media-nixl")]
+            data: storage,
+            tensor_info: MediaTensorInfo {
+                shape,
+                dtype: DataType::UINT8,
+                metadata: None,
+            },
+        })
+    }
+}
+
+// Get NIXL metadata for a descriptor
+// Avoids the possibility of cross-request leaks and reduces metadata size
+// TODO: pre-allocate a fixed NIXL-registered RAM pool so metadata can be cached on the target?
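+// Note: the returned string is prefixed with "b64:" so consumers can detect the encoding.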
+#[cfg(feature = "media-nixl")]
+pub fn get_nixl_metadata(agent: &NixlAgent, _storage: &SystemStorage) -> Result<String> {
+    // WAR: until https://github.com/ai-dynamo/nixl/pull/970 is merged, we can't use get_local_partial_md
+    let nixl_md = agent.raw_agent().get_local_md()?;
+    // let mut reg_desc_list = RegDescList::new(MemType::Dram)?;
+    // reg_desc_list.add_storage_desc(storage)?;
+    // let nixl_partial_md = agent.raw_agent().get_local_partial_md(&reg_desc_list, None)?;
+
+    let b64_encoded = general_purpose::STANDARD.encode(&nixl_md);
+    Ok(format!("b64:{}", b64_encoded))
+}
+
+#[cfg(feature = "media-nixl")]
+pub fn get_nixl_agent() -> Result<NixlAgent> {
+    let name = format!("media-loader-{}", uuid::Uuid::new_v4());
+    let nixl_agent = NixlAgent::with_backends(&name, &["UCX"])?;
+    Ok(nixl_agent)
+}
diff --git a/lib/llm/src/protocols/common/preprocessor.rs b/lib/llm/src/protocols/common/preprocessor.rs
index 71260da5d2..fc8f01cac1 100644
--- a/lib/llm/src/protocols/common/preprocessor.rs
+++ b/lib/llm/src/protocols/common/preprocessor.rs
@@ -6,12 +6,15 @@ use serde::{Deserialize, Serialize};
 
 use super::{OutputOptions, SamplingOptions, StopConditions};
 use crate::kv_router::RouterConfigOverride;
+#[cfg(feature = "media-nixl")]
+use crate::preprocessor::media::RdmaMediaDataDescriptor;
 use crate::protocols::TokenIdType;
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub enum MultimodalData {
     Url(url::Url),
-    // TODO: Decoded(DecodedMediaData),
+    #[cfg(feature = "media-nixl")]
+    Decoded(RdmaMediaDataDescriptor),
 }
 
 // multimodal map containing {mm_part_type: [data...]}
@@ -31,6 +34,7 @@ pub struct PreprocessedRequest {
     #[builder(default)]
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub multi_modal_data: Option<MultimodalDataMap>,
+
     /// StopConditions are conditions that the inference engine will use to stop generation.
     pub stop_conditions: StopConditions,