From 945af997a497946368ae20fac8dec486258f05ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E6=8C=AF=E5=8D=8E?= <1823301126@qq.com> Date: Mon, 2 Mar 2026 23:23:14 +0800 Subject: [PATCH 1/2] feat(kernel): add vocab embedding CUDA kernels Add four vectorized CUDA embedding kernels: - embedding_lookup: standard token embedding - embedding_lookup_with_image: token + image embedding fusion - assemble_deepstack_embedding: extract image-only embeddings - embedding_lookup_multimodal: text + image + audio embedding All 17 tests passed. --- .../mllm_kernel/cuda/csrc/vocab_embedding.cuh | 428 ++++++++++++++ mllm-kernel/mllm_kernel/cuda/jit/__init__.py | 18 +- .../mllm_kernel/cuda/jit/vocab_embedding.py | 307 ++++++++++ mllm-kernel/tests/test_vocab_embedding.py | 547 ++++++++++++++++++ 4 files changed, 1297 insertions(+), 3 deletions(-) create mode 100755 mllm-kernel/mllm_kernel/cuda/csrc/vocab_embedding.cuh create mode 100755 mllm-kernel/mllm_kernel/cuda/jit/vocab_embedding.py create mode 100755 mllm-kernel/tests/test_vocab_embedding.py diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/vocab_embedding.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/vocab_embedding.cuh new file mode 100755 index 00000000..a2ad3b32 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/vocab_embedding.cuh @@ -0,0 +1,428 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +// +// Embedding kernels migrated from TensorRT-Edge-LLM. +// Reference: https://github.com/NVIDIA/TensorRT-Edge-LLM/tree/main/cpp/kernels/embeddingKernels +// +// Supported operations: +// 1. embedding_lookup — standard token embedding +// 2. embedding_lookup_with_image — token + image embedding fusion +// 3. assemble_deepstack_embedding — extract image-only embeddings +// 4. 
embedding_lookup_multimodal — text + image + audio embedding + +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include + +namespace { + +// ─────────────────────────────────────────────────────────────── +// Constants +// ─────────────────────────────────────────────────────────────── + +constexpr uint32_t kEmbNumWarps = 4; +constexpr uint32_t kEmbBlockSize = kEmbNumWarps * device::kWarpThreads; + +// ─────────────────────────────────────────────────────────────── +// Vectorised warp-level row copy / zero +// ─────────────────────────────────────────────────────────────── + +namespace detail { + +template +__device__ __forceinline__ void warp_copy_row(const void* __restrict__ src, + void* __restrict__ dst, + int64_t num_vecs) { + const int lane = threadIdx.x % device::kWarpThreads; + const auto* __restrict__ s = static_cast(src); + auto* __restrict__ d = static_cast(dst); + for (int64_t i = lane; i < num_vecs; i += device::kWarpThreads) { + d[i] = s[i]; + } +} + +template +__device__ __forceinline__ void warp_zero_row(void* __restrict__ dst, + int64_t num_vecs) { + const int lane = threadIdx.x % device::kWarpThreads; + auto* __restrict__ d = static_cast(dst); + const Vec zero{}; + for (int64_t i = lane; i < num_vecs; i += device::kWarpThreads) { + d[i] = zero; + } +} + +} // namespace detail + +__device__ __forceinline__ void copy_or_zero_row(const void* __restrict__ src, + void* __restrict__ dst, + int64_t row_bytes) { + if (row_bytes % 16 == 0) { + const int64_t n = row_bytes / 16; + if (src) detail::warp_copy_row(src, dst, n); + else detail::warp_zero_row(dst, n); + } else if (row_bytes % 8 == 0) { + const int64_t n = row_bytes / 8; + if (src) detail::warp_copy_row(src, dst, n); + else detail::warp_zero_row(dst, n); + } else { + const int64_t n = row_bytes / 4; + if (src) detail::warp_copy_row(src, dst, n); + else detail::warp_zero_row(dst, n); + } +} + +// ─────────────────────────────────────────────────────────────── +// 
Parameter blocks (passed via __grid_constant__) +// ─────────────────────────────────────────────────────────────── + +struct EmbeddingLookupParams { + void* __restrict__ output; + const void* __restrict__ input_ids; + const void* __restrict__ embedding_table; + int64_t num_tokens; + int64_t stride_bytes; + int32_t vocab_size; +}; + +struct EmbeddingLookupWithImageParams { + void* __restrict__ output; + const void* __restrict__ input_ids; + const void* __restrict__ embedding_table; + const void* __restrict__ image_embeds; + int64_t num_tokens; + int64_t stride_bytes; + int32_t vocab_size; + int64_t image_token_len; +}; + +struct AssembleDeepstackParams { + void* __restrict__ output; + const void* __restrict__ input_ids; + const void* __restrict__ deepstack_features; + int64_t num_tokens; + int64_t stride_bytes; + int32_t vocab_size; + int64_t num_image_tokens; +}; + +struct EmbeddingMultimodalParams { + void* __restrict__ output; + const void* __restrict__ input_ids; + const void* __restrict__ embedding_table; + const void* __restrict__ multimodal_indices; + const void* __restrict__ image_embeds; + const void* __restrict__ audio_embeds; + int64_t num_tokens; + int64_t stride_bytes; + int32_t vocab_size; + int32_t image_token_id; + int64_t image_token_len; + int32_t audio_token_id; + int64_t audio_token_len; +}; + +// ─────────────────────────────────────────────────────────────── +// Kernel 1: Standard Embedding Lookup +// ─────────────────────────────────────────────────────────────── + +__global__ void embedding_lookup_kernel( + const __grid_constant__ EmbeddingLookupParams params) { + const uint32_t warp_id = blockIdx.x * blockDim.y + threadIdx.y; + if (warp_id >= params.num_tokens) return; + + const auto token_id = static_cast(params.input_ids)[warp_id]; + + const void* src = nullptr; + if (token_id >= 0 && token_id < params.vocab_size) { + src = device::pointer::offset(params.embedding_table, + static_cast(token_id) * params.stride_bytes); + } + auto* dst = 
device::pointer::offset(params.output, + static_cast(warp_id) * params.stride_bytes); + + copy_or_zero_row(src, dst, params.stride_bytes); +} + +// ─────────────────────────────────────────────────────────────── +// Kernel 2: Embedding Lookup with Image Insertion +// ─────────────────────────────────────────────────────────────── + +__global__ void embedding_lookup_with_image_kernel( + const __grid_constant__ EmbeddingLookupWithImageParams params) { + const uint32_t warp_id = blockIdx.x * blockDim.y + threadIdx.y; + if (warp_id >= params.num_tokens) return; + + const auto token_id = static_cast(params.input_ids)[warp_id]; + + const void* src = nullptr; + if (token_id >= params.vocab_size) { + const int32_t visual_id = token_id - params.vocab_size; + if (visual_id < params.image_token_len) { + src = device::pointer::offset(params.image_embeds, + static_cast(visual_id) * params.stride_bytes); + } + } else if (token_id >= 0) { + src = device::pointer::offset(params.embedding_table, + static_cast(token_id) * params.stride_bytes); + } + auto* dst = device::pointer::offset(params.output, + static_cast(warp_id) * params.stride_bytes); + + copy_or_zero_row(src, dst, params.stride_bytes); +} + +// ─────────────────────────────────────────────────────────────── +// Kernel 3: Assemble Deepstack Embedding +// ─────────────────────────────────────────────────────────────── + +__global__ void assemble_deepstack_embedding_kernel( + const __grid_constant__ AssembleDeepstackParams params) { + const uint32_t warp_id = blockIdx.x * blockDim.y + threadIdx.y; + if (warp_id >= params.num_tokens) return; + + const auto token_id = static_cast(params.input_ids)[warp_id]; + + const void* src = nullptr; + if (token_id >= params.vocab_size) { + const int32_t ds_idx = token_id - params.vocab_size; + if (ds_idx < params.num_image_tokens) { + src = device::pointer::offset(params.deepstack_features, + static_cast(ds_idx) * params.stride_bytes); + } + } + auto* dst = 
device::pointer::offset(params.output, + static_cast(warp_id) * params.stride_bytes); + + copy_or_zero_row(src, dst, params.stride_bytes); +} + +// ─────────────────────────────────────────────────────────────── +// Kernel 4: Multimodal Embedding Lookup +// ─────────────────────────────────────────────────────────────── + +__global__ void embedding_lookup_multimodal_kernel( + const __grid_constant__ EmbeddingMultimodalParams params) { + const uint32_t warp_id = blockIdx.x * blockDim.y + threadIdx.y; + if (warp_id >= params.num_tokens) return; + + const auto token_id = static_cast(params.input_ids)[warp_id]; + + const void* src = nullptr; + if (params.image_embeds != nullptr && token_id == params.image_token_id) { + const auto idx = static_cast(params.multimodal_indices)[warp_id]; + if (idx >= 0 && idx < params.image_token_len) { + src = device::pointer::offset(params.image_embeds, + static_cast(idx) * params.stride_bytes); + } + } else if (params.audio_embeds != nullptr && token_id == params.audio_token_id) { + const auto idx = static_cast(params.multimodal_indices)[warp_id]; + if (idx >= 0 && idx < params.audio_token_len) { + src = device::pointer::offset(params.audio_embeds, + static_cast(idx) * params.stride_bytes); + } + } else if (token_id >= 0 && token_id < params.vocab_size) { + src = device::pointer::offset(params.embedding_table, + static_cast(token_id) * params.stride_bytes); + } + auto* dst = device::pointer::offset(params.output, + static_cast(warp_id) * params.stride_bytes); + + copy_or_zero_row(src, dst, params.stride_bytes); +} + +// ─────────────────────────────────────────────────────────────── +// Host-side launch wrappers +// ─────────────────────────────────────────────────────────────── + +void embedding_lookup( + tvm::ffi::TensorView output, + tvm::ffi::TensorView input_ids, + tvm::ffi::TensorView embedding_table) { + using namespace mllm_kernel::host; + + auto N = SymbolicSize{"num_tokens"}; + auto V = SymbolicSize{"vocab_size"}; + auto H = 
SymbolicSize{"hidden_size"}; + auto dtype = SymbolicDType{}; + auto device = SymbolicDevice{}; + device.set_options(); + dtype.set_options(); + + (void)TensorMatcher({N}).with_dtype().with_device().verify(input_ids); + (void)TensorMatcher({V, H}).with_dtype(dtype).with_device(device).verify(embedding_table); + (void)TensorMatcher({N, H}).with_dtype(dtype).with_device(device).verify(output); + + const int64_t dtype_size = dtype_bytes(dtype.unwrap()); + const auto num_tokens = static_cast(N.unwrap()); + const auto stride_bytes = H.unwrap() * dtype_size; + + RuntimeCheck(stride_bytes % 4 == 0, + "stride_bytes must be at least 4-byte aligned, got ", stride_bytes); + + const auto params = EmbeddingLookupParams{ + .output = output.data_ptr(), + .input_ids = input_ids.data_ptr(), + .embedding_table = embedding_table.data_ptr(), + .num_tokens = static_cast(num_tokens), + .stride_bytes = stride_bytes, + .vocab_size = static_cast(V.unwrap()), + }; + + const dim3 block(device::kWarpThreads, kEmbNumWarps); + const auto grid = div_ceil(num_tokens, kEmbNumWarps); + LaunchKernel(grid, block, device.unwrap())(embedding_lookup_kernel, params); +} + +void embedding_lookup_with_image( + tvm::ffi::TensorView output, + tvm::ffi::TensorView input_ids, + tvm::ffi::TensorView embedding_table, + tvm::ffi::TensorView image_embeds) { + using namespace mllm_kernel::host; + + auto N = SymbolicSize{"num_tokens"}; + auto V = SymbolicSize{"vocab_size"}; + auto H = SymbolicSize{"hidden_size"}; + auto I = SymbolicSize{"image_token_len"}; + auto dtype = SymbolicDType{}; + auto device = SymbolicDevice{}; + device.set_options(); + dtype.set_options(); + + (void)TensorMatcher({N}).with_dtype().with_device().verify(input_ids); + (void)TensorMatcher({V, H}).with_dtype(dtype).with_device(device).verify(embedding_table); + (void)TensorMatcher({I, H}).with_dtype(dtype).with_device(device).verify(image_embeds); + (void)TensorMatcher({N, H}).with_dtype(dtype).with_device(device).verify(output); + + const 
int64_t dtype_size = dtype_bytes(dtype.unwrap()); + const auto num_tokens = static_cast(N.unwrap()); + const auto stride_bytes = H.unwrap() * dtype_size; + + RuntimeCheck(stride_bytes % 4 == 0, + "stride_bytes must be at least 4-byte aligned, got ", stride_bytes); + + const auto params = EmbeddingLookupWithImageParams{ + .output = output.data_ptr(), + .input_ids = input_ids.data_ptr(), + .embedding_table = embedding_table.data_ptr(), + .image_embeds = image_embeds.data_ptr(), + .num_tokens = static_cast(num_tokens), + .stride_bytes = stride_bytes, + .vocab_size = static_cast(V.unwrap()), + .image_token_len = I.unwrap(), + }; + + const dim3 block(device::kWarpThreads, kEmbNumWarps); + const auto grid = div_ceil(num_tokens, kEmbNumWarps); + LaunchKernel(grid, block, device.unwrap())(embedding_lookup_with_image_kernel, params); +} + +void assemble_deepstack_embedding( + tvm::ffi::TensorView output, + tvm::ffi::TensorView input_ids, + tvm::ffi::TensorView deepstack_features, + int vocab_size) { + using namespace mllm_kernel::host; + + auto N = SymbolicSize{"num_tokens"}; + auto F = SymbolicSize{"num_image_tokens"}; + auto H = SymbolicSize{"hidden_size"}; + auto dtype = SymbolicDType{}; + auto device = SymbolicDevice{}; + device.set_options(); + dtype.set_options(); + + (void)TensorMatcher({N}).with_dtype().with_device().verify(input_ids); + (void)TensorMatcher({F, H}).with_dtype(dtype).with_device(device).verify(deepstack_features); + (void)TensorMatcher({N, H}).with_dtype(dtype).with_device(device).verify(output); + + const int64_t dtype_size = dtype_bytes(dtype.unwrap()); + const auto num_tokens = static_cast(N.unwrap()); + const auto stride_bytes = H.unwrap() * dtype_size; + + RuntimeCheck(stride_bytes % 4 == 0, + "stride_bytes must be at least 4-byte aligned, got ", stride_bytes); + + const auto params = AssembleDeepstackParams{ + .output = output.data_ptr(), + .input_ids = input_ids.data_ptr(), + .deepstack_features = deepstack_features.data_ptr(), + .num_tokens = 
static_cast(num_tokens), + .stride_bytes = stride_bytes, + .vocab_size = static_cast(vocab_size), + .num_image_tokens = F.unwrap(), + }; + + const dim3 block(device::kWarpThreads, kEmbNumWarps); + const auto grid = div_ceil(num_tokens, kEmbNumWarps); + LaunchKernel(grid, block, device.unwrap())(assemble_deepstack_embedding_kernel, params); +} + +void embedding_lookup_multimodal( + tvm::ffi::TensorView output, + tvm::ffi::TensorView input_ids, + tvm::ffi::TensorView embedding_table, + tvm::ffi::TensorView multimodal_indices, + tvm::ffi::TensorView image_embeds, + tvm::ffi::TensorView audio_embeds, + int image_token_id, + int audio_token_id) { + using namespace mllm_kernel::host; + + auto N = SymbolicSize{"num_tokens"}; + auto V = SymbolicSize{"vocab_size"}; + auto H = SymbolicSize{"hidden_size"}; + auto IL = SymbolicSize{"image_token_len"}; + auto AL = SymbolicSize{"audio_token_len"}; + auto dtype = SymbolicDType{}; + auto device = SymbolicDevice{}; + device.set_options(); + dtype.set_options(); + + (void)TensorMatcher({N}).with_dtype().with_device().verify(input_ids); + (void)TensorMatcher({V, H}).with_dtype(dtype).with_device(device).verify(embedding_table); + (void)TensorMatcher({N}).with_dtype().with_device().verify(multimodal_indices); + (void)TensorMatcher({IL, H}).with_dtype(dtype).with_device(device).verify(image_embeds); + (void)TensorMatcher({AL, H}).with_dtype(dtype).with_device(device).verify(audio_embeds); + (void)TensorMatcher({N, H}).with_dtype(dtype).with_device(device).verify(output); + + const int64_t dtype_size = dtype_bytes(dtype.unwrap()); + const auto num_tokens = static_cast(N.unwrap()); + const auto stride_bytes = H.unwrap() * dtype_size; + const auto image_token_len = IL.unwrap(); + const auto audio_token_len = AL.unwrap(); + + RuntimeCheck(stride_bytes % 4 == 0, + "stride_bytes must be at least 4-byte aligned, got ", stride_bytes); + + const auto params = EmbeddingMultimodalParams{ + .output = output.data_ptr(), + .input_ids = 
input_ids.data_ptr(), + .embedding_table = embedding_table.data_ptr(), + .multimodal_indices = multimodal_indices.data_ptr(), + .image_embeds = (image_token_len > 0) ? image_embeds.data_ptr() : nullptr, + .audio_embeds = (audio_token_len > 0) ? audio_embeds.data_ptr() : nullptr, + .num_tokens = static_cast(num_tokens), + .stride_bytes = stride_bytes, + .vocab_size = static_cast(V.unwrap()), + .image_token_id = static_cast(image_token_id), + .image_token_len = image_token_len, + .audio_token_id = static_cast(audio_token_id), + .audio_token_len = audio_token_len, + }; + + const dim3 block(device::kWarpThreads, kEmbNumWarps); + const auto grid = div_ceil(num_tokens, kEmbNumWarps); + LaunchKernel(grid, block, device.unwrap())(embedding_lookup_multimodal_kernel, params); +} + +} // namespace \ No newline at end of file diff --git a/mllm-kernel/mllm_kernel/cuda/jit/__init__.py b/mllm-kernel/mllm_kernel/cuda/jit/__init__.py index 696e73ea..db83897e 100644 --- a/mllm-kernel/mllm_kernel/cuda/jit/__init__.py +++ b/mllm-kernel/mllm_kernel/cuda/jit/__init__.py @@ -1,3 +1,15 @@ -from .add_constant import add_constant - -__all__ = ["add_constant"] +from .add_constant import add_constant +from .vocab_embedding import ( + assemble_deepstack_embedding, + embedding_lookup, + embedding_lookup_multimodal, + embedding_lookup_with_image, +) + +__all__ = [ + "add_constant", + "assemble_deepstack_embedding", + "embedding_lookup", + "embedding_lookup_multimodal", + "embedding_lookup_with_image", +] \ No newline at end of file diff --git a/mllm-kernel/mllm_kernel/cuda/jit/vocab_embedding.py b/mllm-kernel/mllm_kernel/cuda/jit/vocab_embedding.py new file mode 100755 index 00000000..90985939 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/jit/vocab_embedding.py @@ -0,0 +1,307 @@ +# Copyright (c) MLLM Team. +# Licensed under the MIT License. +# Embedding kernels migrated from TensorRT-Edge-LLM. 
+# Reference: https://github.com/NVIDIA/TensorRT-Edge-LLM/tree/main/cpp/kernels/embeddingKernels + +from __future__ import annotations + +from typing import Optional + +import torch + +from mllm_kernel.jit_utils import jit + + +# ============================================================================ +# Op 1: embedding_lookup +# ============================================================================ + + +@jit( + device="cuda", + cuda_files=["vocab_embedding.cuh"], + cpp_wrappers=[], + cuda_wrappers=[("embedding_lookup", "embedding_lookup")], + func_name="embedding_lookup", +) +def _embedding_lookup_kernel( + compiled_module, output: torch.Tensor, input_ids: torch.Tensor, embedding_table: torch.Tensor +) -> None: + compiled_module.embedding_lookup(output, input_ids, embedding_table) + + +def embedding_lookup(input_ids: torch.Tensor, embedding_table: torch.Tensor) -> torch.Tensor: + """ + Standard embedding lookup using vectorized CUDA kernel. + + Maps each token ID in input_ids to its corresponding row in embedding_table. + Out-of-range token IDs produce zero vectors. + + Args: + input_ids: Token IDs, shape [num_tokens], dtype int32, device cuda. + embedding_table: Embedding weight matrix, shape [vocab_size, hidden_size], + dtype float16 or bfloat16, device cuda. + + Returns: + Embedded output, shape [num_tokens, hidden_size], dtype matching embedding_table. 
+ + Example: + >>> import torch + >>> from mllm_kernel.cuda.jit.vocab_embedding import embedding_lookup + >>> ids = torch.tensor([0, 3, 7], dtype=torch.int32, device="cuda") + >>> table = torch.randn(100, 256, dtype=torch.float16, device="cuda") + >>> out = embedding_lookup(ids, table) + >>> assert out.shape == (3, 256) + """ + if input_ids.dtype != torch.int32: + input_ids = input_ids.to(torch.int32) + if not input_ids.is_contiguous(): + input_ids = input_ids.contiguous() + if not embedding_table.is_contiguous(): + embedding_table = embedding_table.contiguous() + + num_tokens = input_ids.shape[0] + hidden_size = embedding_table.shape[1] + output = torch.empty(num_tokens, hidden_size, dtype=embedding_table.dtype, device=input_ids.device) + + _embedding_lookup_kernel(output, input_ids, embedding_table) + return output + + +# ============================================================================ +# Op 2: embedding_lookup_with_image +# ============================================================================ + + +@jit( + device="cuda", + cuda_files=["vocab_embedding.cuh"], + cpp_wrappers=[], + cuda_wrappers=[("embedding_lookup_with_image", "embedding_lookup_with_image")], + func_name="embedding_lookup_with_image", +) +def _embedding_lookup_with_image_kernel( + compiled_module, + output: torch.Tensor, + input_ids: torch.Tensor, + embedding_table: torch.Tensor, + image_embeds: torch.Tensor, +) -> None: + compiled_module.embedding_lookup_with_image(output, input_ids, embedding_table, image_embeds) + + +def embedding_lookup_with_image( + input_ids: torch.Tensor, + embedding_table: torch.Tensor, + image_embeds: torch.Tensor, +) -> torch.Tensor: + """ + Embedding lookup with image embedding insertion. + + For token IDs in [0, vocab_size): lookup from embedding_table. + For token IDs >= vocab_size: lookup from image_embeds at index (token_id - vocab_size). + + Args: + input_ids: Token IDs, shape [num_tokens], dtype int32, device cuda. 
+ embedding_table: Text embedding table, shape [vocab_size, hidden_size], + dtype float16 or bfloat16, device cuda. + image_embeds: Image embeddings, shape [image_token_len, hidden_size], + dtype matching embedding_table, device cuda. + + Returns: + Embedded output, shape [num_tokens, hidden_size], dtype matching embedding_table. + + Example: + >>> import torch + >>> from mllm_kernel.cuda.jit.vocab_embedding import embedding_lookup_with_image + >>> ids = torch.tensor([0, 100, 101], dtype=torch.int32, device="cuda") + >>> table = torch.randn(100, 256, dtype=torch.float16, device="cuda") + >>> img = torch.randn(10, 256, dtype=torch.float16, device="cuda") + >>> out = embedding_lookup_with_image(ids, table, img) + >>> assert out.shape == (3, 256) + """ + if input_ids.dtype != torch.int32: + input_ids = input_ids.to(torch.int32) + if not input_ids.is_contiguous(): + input_ids = input_ids.contiguous() + if not embedding_table.is_contiguous(): + embedding_table = embedding_table.contiguous() + if not image_embeds.is_contiguous(): + image_embeds = image_embeds.contiguous() + + num_tokens = input_ids.shape[0] + hidden_size = embedding_table.shape[1] + output = torch.empty(num_tokens, hidden_size, dtype=embedding_table.dtype, device=input_ids.device) + + _embedding_lookup_with_image_kernel(output, input_ids, embedding_table, image_embeds) + return output + + +# ============================================================================ +# Op 3: assemble_deepstack_embedding +# ============================================================================ + + +@jit( + device="cuda", + cuda_files=["vocab_embedding.cuh"], + cpp_wrappers=[], + cuda_wrappers=[("assemble_deepstack_embedding", "assemble_deepstack_embedding")], + func_name="assemble_deepstack_embedding", +) +def _assemble_deepstack_embedding_kernel( + compiled_module, + output: torch.Tensor, + input_ids: torch.Tensor, + deepstack_features: torch.Tensor, + vocab_size: int, +) -> None: + 
compiled_module.assemble_deepstack_embedding(output, input_ids, deepstack_features, vocab_size) + + +def assemble_deepstack_embedding( + input_ids: torch.Tensor, + deepstack_features: torch.Tensor, + vocab_size: int, +) -> torch.Tensor: + """ + Extract image-only embeddings from deepstack features. + + Token IDs >= vocab_size: lookup from deepstack_features at index (token_id - vocab_size). + Token IDs < vocab_size: zero output (text tokens handled elsewhere). + + Args: + input_ids: Token IDs, shape [num_tokens], dtype int32, device cuda. + deepstack_features: Deepstack feature embeddings, shape [num_image_tokens, hidden_size], + dtype float16 or bfloat16, device cuda. + vocab_size: Vocabulary size (threshold for image token detection). + + Returns: + Embedded output, shape [num_tokens, hidden_size], dtype matching deepstack_features. + + Example: + >>> import torch + >>> from mllm_kernel.cuda.jit.vocab_embedding import assemble_deepstack_embedding + >>> ids = torch.tensor([0, 100, 101], dtype=torch.int32, device="cuda") + >>> features = torch.randn(10, 256, dtype=torch.float16, device="cuda") + >>> out = assemble_deepstack_embedding(ids, features, vocab_size=100) + >>> assert out.shape == (3, 256) + """ + if input_ids.dtype != torch.int32: + input_ids = input_ids.to(torch.int32) + if not input_ids.is_contiguous(): + input_ids = input_ids.contiguous() + if not deepstack_features.is_contiguous(): + deepstack_features = deepstack_features.contiguous() + + num_tokens = input_ids.shape[0] + hidden_size = deepstack_features.shape[1] + output = torch.empty(num_tokens, hidden_size, dtype=deepstack_features.dtype, device=input_ids.device) + + _assemble_deepstack_embedding_kernel(output, input_ids, deepstack_features, vocab_size) + return output + + +# ============================================================================ +# Op 4: embedding_lookup_multimodal +# ============================================================================ + + +@jit( + device="cuda", + 
cuda_files=["vocab_embedding.cuh"], + cpp_wrappers=[], + cuda_wrappers=[("embedding_lookup_multimodal", "embedding_lookup_multimodal")], + func_name="embedding_lookup_multimodal", +) +def _embedding_lookup_multimodal_kernel( + compiled_module, + output: torch.Tensor, + input_ids: torch.Tensor, + embedding_table: torch.Tensor, + multimodal_indices: torch.Tensor, + image_embeds: torch.Tensor, + audio_embeds: torch.Tensor, + image_token_id: int, + audio_token_id: int, +) -> None: + compiled_module.embedding_lookup_multimodal( + output, input_ids, embedding_table, multimodal_indices, + image_embeds, audio_embeds, image_token_id, audio_token_id + ) + + +def embedding_lookup_multimodal( + input_ids: torch.Tensor, + embedding_table: torch.Tensor, + multimodal_indices: torch.Tensor, + image_embeds: Optional[torch.Tensor] = None, + audio_embeds: Optional[torch.Tensor] = None, + image_token_id: int = -1, + audio_token_id: int = -1, +) -> torch.Tensor: + """ + Multimodal embedding lookup supporting text, image, and audio tokens. + + Uses multimodal_indices to determine the embedding index for image/audio tokens. + Token IDs matching image_token_id or audio_token_id are looked up from their + respective embedding tables. Other token IDs are looked up from embedding_table. + + Args: + input_ids: Token IDs, shape [num_tokens], dtype int32, device cuda. + embedding_table: Text embedding table, shape [vocab_size, hidden_size], + dtype float16 or bfloat16, device cuda. + multimodal_indices: Indices for image/audio embeddings, shape [num_tokens], + dtype int32, device cuda. + image_embeds: Image embeddings, shape [image_token_len, hidden_size], + dtype matching embedding_table, device cuda. Optional. + audio_embeds: Audio embeddings, shape [audio_token_len, hidden_size], + dtype matching embedding_table, device cuda. Optional. + image_token_id: Special token ID for image tokens. Default -1 (disabled). + audio_token_id: Special token ID for audio tokens. Default -1 (disabled). 
+ + Returns: + Embedded output, shape [num_tokens, hidden_size], dtype matching embedding_table. + + Example: + >>> import torch + >>> from mllm_kernel.cuda.jit.vocab_embedding import embedding_lookup_multimodal + >>> ids = torch.tensor([0, 32000, 32001, 1], dtype=torch.int32, device="cuda") + >>> table = torch.randn(32000, 256, dtype=torch.float16, device="cuda") + >>> indices = torch.tensor([0, 0, 1, 0], dtype=torch.int32, device="cuda") + >>> img = torch.randn(10, 256, dtype=torch.float16, device="cuda") + >>> aud = torch.randn(5, 256, dtype=torch.float16, device="cuda") + >>> out = embedding_lookup_multimodal(ids, table, indices, img, aud, 32000, 32001) + >>> assert out.shape == (4, 256) + """ + if input_ids.dtype != torch.int32: + input_ids = input_ids.to(torch.int32) + if not input_ids.is_contiguous(): + input_ids = input_ids.contiguous() + if not embedding_table.is_contiguous(): + embedding_table = embedding_table.contiguous() + if not multimodal_indices.is_contiguous(): + multimodal_indices = multimodal_indices.contiguous() + + # Handle optional image_embeds and audio_embeds [1] + if image_embeds is None: + image_embeds = torch.empty(0, embedding_table.shape[1], dtype=embedding_table.dtype, device=input_ids.device) + else: + if not image_embeds.is_contiguous(): + image_embeds = image_embeds.contiguous() + + if audio_embeds is None: + audio_embeds = torch.empty(0, embedding_table.shape[1], dtype=embedding_table.dtype, device=input_ids.device) + else: + if not audio_embeds.is_contiguous(): + audio_embeds = audio_embeds.contiguous() + + num_tokens = input_ids.shape[0] + hidden_size = embedding_table.shape[1] + output = torch.empty(num_tokens, hidden_size, dtype=embedding_table.dtype, device=input_ids.device) + + _embedding_lookup_multimodal_kernel( + output, input_ids, embedding_table, multimodal_indices, + image_embeds, audio_embeds, image_token_id, audio_token_id + ) + return output \ No newline at end of file diff --git 
a/mllm-kernel/tests/test_vocab_embedding.py b/mllm-kernel/tests/test_vocab_embedding.py new file mode 100755 index 00000000..0def56a8 --- /dev/null +++ b/mllm-kernel/tests/test_vocab_embedding.py @@ -0,0 +1,547 @@ +from __future__ import annotations + +import pytest +import torch + +from mllm_kernel.cuda.jit.vocab_embedding import ( + assemble_deepstack_embedding, + embedding_lookup, + embedding_lookup_multimodal, + embedding_lookup_with_image, +) + + +def _make_lookup_inputs( + *, + num_tokens: int, + vocab_size: int, + hidden_size: int, + dtype: torch.dtype, + seed: int = 0, +): + """Build random (input_ids, embedding_table) for embedding_lookup. + + All token IDs are valid indices into embedding_table, so the result + should be bit-exact against torch.index_select. + """ + torch.manual_seed(seed) + device = "cuda" + input_ids = torch.randint( + 0, vocab_size, (num_tokens,), device=device, dtype=torch.int32 + ) + embedding_table = torch.randn( + vocab_size, hidden_size, device=device, dtype=dtype + ) + return input_ids, embedding_table + + +def _make_lookup_with_image_inputs( + *, + num_tokens: int, + vocab_size: int, + hidden_size: int, + image_token_len: int, + dtype: torch.dtype, + seed: int = 0, +): + """Build mixed text/image inputs for embedding_lookup_with_image. + + The first half of input_ids are text tokens in [0, vocab_size), + the second half are image tokens in [vocab_size, vocab_size + image_token_len). 
+ """ + torch.manual_seed(seed) + device = "cuda" + n_text = num_tokens // 2 + n_image = num_tokens - n_text + text_ids = torch.randint( + 0, vocab_size, (n_text,), device=device, dtype=torch.int32 + ) + image_ids = torch.randint( + vocab_size, vocab_size + image_token_len, + (n_image,), device=device, dtype=torch.int32, + ) + input_ids = torch.cat([text_ids, image_ids]) + embedding_table = torch.randn( + vocab_size, hidden_size, device=device, dtype=dtype + ) + image_embeds = torch.randn( + image_token_len, hidden_size, device=device, dtype=dtype + ) + return input_ids, embedding_table, image_embeds + + +def _make_deepstack_inputs( + *, + num_tokens: int, + vocab_size: int, + hidden_size: int, + num_image_tokens: int, + dtype: torch.dtype, + seed: int = 0, +): + """Build image-only inputs for assemble_deepstack_embedding. + + All token IDs are in [vocab_size, vocab_size + num_image_tokens). + """ + torch.manual_seed(seed) + device = "cuda" + input_ids = torch.randint( + vocab_size, vocab_size + num_image_tokens, + (num_tokens,), device=device, dtype=torch.int32, + ) + deepstack_features = torch.randn( + num_image_tokens, hidden_size, device=device, dtype=dtype + ) + return input_ids, deepstack_features + + +def _make_multimodal_inputs( + *, + num_tokens: int, + vocab_size: int, + hidden_size: int, + image_token_len: int, + audio_token_len: int, + image_token_id: int, + audio_token_id: int, + dtype: torch.dtype, + seed: int = 0, +): + """Build mixed text/image/audio inputs for embedding_lookup_multimodal. + + Roughly one third each of text, image, and audio tokens, shuffled. + multimodal_indices are valid for each token type: image indices in + [0, image_token_len), audio indices in [0, audio_token_len). 
+ """ + torch.manual_seed(seed) + device = "cuda" + + n_text = num_tokens // 3 + n_image = num_tokens // 3 + n_audio = num_tokens - n_text - n_image + + text_ids = torch.randint( + 0, vocab_size, (n_text,), device=device, dtype=torch.int32 + ) + image_ids = torch.full( + (n_image,), image_token_id, device=device, dtype=torch.int32 + ) + audio_ids = torch.full( + (n_audio,), audio_token_id, device=device, dtype=torch.int32 + ) + input_ids = torch.cat([text_ids, image_ids, audio_ids]) + + text_idx = torch.zeros(n_text, device=device, dtype=torch.int32) + image_idx = torch.randint( + 0, image_token_len, (n_image,), device=device, dtype=torch.int32 + ) + audio_idx = torch.randint( + 0, audio_token_len, (n_audio,), device=device, dtype=torch.int32 + ) + multimodal_indices = torch.cat([text_idx, image_idx, audio_idx]) + + perm = torch.randperm(num_tokens, device=device) + input_ids = input_ids[perm] + multimodal_indices = multimodal_indices[perm] + + embedding_table = torch.randn( + vocab_size, hidden_size, device=device, dtype=dtype + ) + image_embeds = torch.randn( + image_token_len, hidden_size, device=device, dtype=dtype + ) + audio_embeds = torch.randn( + audio_token_len, hidden_size, device=device, dtype=dtype + ) + return input_ids, embedding_table, multimodal_indices, image_embeds, audio_embeds + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_embedding_lookup_matches_torch(dtype: torch.dtype): + """embedding_lookup must produce bit-exact results vs torch.index_select.""" + num_tokens = 257 + vocab_size = 32000 + hidden_size = 1024 + + input_ids, embedding_table = _make_lookup_inputs( + num_tokens=num_tokens, + vocab_size=vocab_size, + hidden_size=hidden_size, + dtype=dtype, + seed=2026, + ) + + output_ref = torch.index_select(embedding_table, 0, input_ids) + + output = embedding_lookup(input_ids, embedding_table) + torch.cuda.synchronize() + + assert 
torch.equal(output, output_ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_embedding_lookup_out_of_range_tokens(): + """Out-of-range token IDs (negative or >= vocab_size) must produce zero vectors.""" + vocab_size = 100 + hidden_size = 64 + dtype = torch.float16 + + torch.manual_seed(2026) + device = "cuda" + embedding_table = torch.randn(vocab_size, hidden_size, device=device, dtype=dtype) + input_ids = torch.tensor( + [-1, vocab_size, vocab_size + 100], device=device, dtype=torch.int32 + ) + + output = embedding_lookup(input_ids, embedding_table) + torch.cuda.synchronize() + + expected = torch.zeros(3, hidden_size, dtype=dtype, device=device) + assert torch.equal(output, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_embedding_lookup_with_image_matches_reference(dtype: torch.dtype): + """embedding_lookup_with_image must match a naive per-token reference. + + Text tokens [0, vocab_size) come from embedding_table; image tokens + [vocab_size, vocab_size + image_token_len) come from image_embeds. 
+ """ + num_tokens = 257 + vocab_size = 1000 + hidden_size = 1024 + image_token_len = 576 + + input_ids, embedding_table, image_embeds = _make_lookup_with_image_inputs( + num_tokens=num_tokens, + vocab_size=vocab_size, + hidden_size=hidden_size, + image_token_len=image_token_len, + dtype=dtype, + seed=2026, + ) + + input_ids_cpu = input_ids.cpu() + embedding_table_cpu = embedding_table.cpu() + image_embeds_cpu = image_embeds.cpu() + ref = torch.zeros(num_tokens, hidden_size, dtype=dtype) + for i in range(num_tokens): + tid = input_ids_cpu[i].item() + if tid >= vocab_size: + ref[i] = image_embeds_cpu[tid - vocab_size] + elif tid >= 0: + ref[i] = embedding_table_cpu[tid] + + output = embedding_lookup_with_image(input_ids, embedding_table, image_embeds) + torch.cuda.synchronize() + + assert torch.equal(output.cpu(), ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_embedding_lookup_with_image_out_of_range(): + """Image token indices beyond image_token_len must produce zero vectors.""" + vocab_size = 100 + hidden_size = 64 + image_token_len = 10 + dtype = torch.float16 + + torch.manual_seed(2026) + device = "cuda" + embedding_table = torch.randn(vocab_size, hidden_size, device=device, dtype=dtype) + image_embeds = torch.randn(image_token_len, hidden_size, device=device, dtype=dtype) + input_ids = torch.tensor( + [vocab_size + image_token_len, vocab_size + image_token_len + 100], + device=device, dtype=torch.int32, + ) + + output = embedding_lookup_with_image(input_ids, embedding_table, image_embeds) + torch.cuda.synchronize() + + expected = torch.zeros(2, hidden_size, dtype=dtype, device=device) + assert torch.equal(output, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_assemble_deepstack_embedding_matches_reference(dtype: torch.dtype): + """assemble_deepstack_embedding must extract image embeddings 
correctly. + + Token IDs >= vocab_size are looked up from deepstack_features at + index (token_id - vocab_size). Token IDs < vocab_size produce zeros. + """ + num_tokens = 257 + vocab_size = 1000 + hidden_size = 1024 + num_image_tokens = 576 + + input_ids, deepstack_features = _make_deepstack_inputs( + num_tokens=num_tokens, + vocab_size=vocab_size, + hidden_size=hidden_size, + num_image_tokens=num_image_tokens, + dtype=dtype, + seed=2026, + ) + + input_ids_cpu = input_ids.cpu() + features_cpu = deepstack_features.cpu() + ref = torch.zeros(num_tokens, hidden_size, dtype=dtype) + for i in range(num_tokens): + tid = input_ids_cpu[i].item() + if tid >= vocab_size: + ref[i] = features_cpu[tid - vocab_size] + + output = assemble_deepstack_embedding(input_ids, deepstack_features, vocab_size) + torch.cuda.synchronize() + + assert torch.equal(output.cpu(), ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_assemble_deepstack_text_tokens_produce_zeros(): + """All-text input must produce an all-zero output.""" + num_tokens = 32 + vocab_size = 100 + hidden_size = 64 + num_image_tokens = 10 + dtype = torch.float16 + + torch.manual_seed(2026) + device = "cuda" + input_ids = torch.randint( + 0, vocab_size, (num_tokens,), device=device, dtype=torch.int32 + ) + deepstack_features = torch.randn( + num_image_tokens, hidden_size, device=device, dtype=dtype + ) + + output = assemble_deepstack_embedding(input_ids, deepstack_features, vocab_size) + torch.cuda.synchronize() + + expected = torch.zeros(num_tokens, hidden_size, dtype=dtype, device=device) + assert torch.equal(output, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_embedding_lookup_multimodal_matches_reference(dtype: torch.dtype): + """embedding_lookup_multimodal must match a naive per-token reference. 
+ + Token IDs matching image_token_id use multimodal_indices into image_embeds. + Token IDs matching audio_token_id use multimodal_indices into audio_embeds. + Other valid token IDs come from embedding_table. + """ + num_tokens = 257 + vocab_size = 1000 + hidden_size = 1024 + image_token_len = 50 + audio_token_len = 30 + image_token_id = 32000 + audio_token_id = 32001 + + ( + input_ids, + embedding_table, + multimodal_indices, + image_embeds, + audio_embeds, + ) = _make_multimodal_inputs( + num_tokens=num_tokens, + vocab_size=vocab_size, + hidden_size=hidden_size, + image_token_len=image_token_len, + audio_token_len=audio_token_len, + image_token_id=image_token_id, + audio_token_id=audio_token_id, + dtype=dtype, + seed=2026, + ) + + input_ids_cpu = input_ids.cpu() + mm_indices_cpu = multimodal_indices.cpu() + embedding_table_cpu = embedding_table.cpu() + image_embeds_cpu = image_embeds.cpu() + audio_embeds_cpu = audio_embeds.cpu() + ref = torch.zeros(num_tokens, hidden_size, dtype=dtype) + for i in range(num_tokens): + tid = input_ids_cpu[i].item() + idx = mm_indices_cpu[i].item() + if tid == image_token_id and 0 <= idx < image_token_len: + ref[i] = image_embeds_cpu[idx] + elif tid == audio_token_id and 0 <= idx < audio_token_len: + ref[i] = audio_embeds_cpu[idx] + elif 0 <= tid < vocab_size: + ref[i] = embedding_table_cpu[tid] + + output = embedding_lookup_multimodal( + input_ids, embedding_table, multimodal_indices, + image_embeds, audio_embeds, image_token_id, audio_token_id, + ) + torch.cuda.synchronize() + + assert torch.equal(output.cpu(), ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_embedding_lookup_multimodal_none_embeds(): + """Passing None for both image and audio embeds must fall back to text-only lookup.""" + num_tokens = 64 + vocab_size = 1000 + hidden_size = 128 + image_token_id = 32000 + audio_token_id = 32001 + dtype = torch.float16 + + torch.manual_seed(2026) + device = "cuda" + input_ids = 
torch.randint( + 0, vocab_size, (num_tokens,), device=device, dtype=torch.int32 + ) + embedding_table = torch.randn( + vocab_size, hidden_size, device=device, dtype=dtype + ) + multimodal_indices = torch.zeros( + num_tokens, device=device, dtype=torch.int32 + ) + + output = embedding_lookup_multimodal( + input_ids, embedding_table, multimodal_indices, + image_embeds=None, audio_embeds=None, + image_token_id=image_token_id, audio_token_id=audio_token_id, + ) + torch.cuda.synchronize() + + output_ref = torch.index_select(embedding_table, 0, input_ids) + assert torch.equal(output, output_ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_embedding_lookup_multimodal_image_only(): + """Multimodal lookup with only image embeds (audio_embeds=None).""" + num_tokens = 64 + vocab_size = 1000 + hidden_size = 128 + image_token_len = 50 + image_token_id = 32000 + audio_token_id = 32001 + dtype = torch.float16 + + torch.manual_seed(2026) + device = "cuda" + + n_text = num_tokens // 2 + n_image = num_tokens - n_text + text_ids = torch.randint( + 0, vocab_size, (n_text,), device=device, dtype=torch.int32 + ) + image_ids = torch.full( + (n_image,), image_token_id, device=device, dtype=torch.int32 + ) + input_ids = torch.cat([text_ids, image_ids]) + + text_idx = torch.zeros(n_text, device=device, dtype=torch.int32) + image_idx = torch.randint( + 0, image_token_len, (n_image,), device=device, dtype=torch.int32 + ) + multimodal_indices = torch.cat([text_idx, image_idx]) + + embedding_table = torch.randn( + vocab_size, hidden_size, device=device, dtype=dtype + ) + image_embeds = torch.randn( + image_token_len, hidden_size, device=device, dtype=dtype + ) + + output = embedding_lookup_multimodal( + input_ids, embedding_table, multimodal_indices, + image_embeds=image_embeds, audio_embeds=None, + image_token_id=image_token_id, audio_token_id=audio_token_id, + ) + torch.cuda.synchronize() + + input_ids_cpu = input_ids.cpu() + mm_indices_cpu = 
multimodal_indices.cpu() + embedding_table_cpu = embedding_table.cpu() + image_embeds_cpu = image_embeds.cpu() + ref = torch.zeros(num_tokens, hidden_size, dtype=dtype) + for i in range(num_tokens): + tid = input_ids_cpu[i].item() + idx = mm_indices_cpu[i].item() + if tid == image_token_id and 0 <= idx < image_token_len: + ref[i] = image_embeds_cpu[idx] + elif 0 <= tid < vocab_size: + ref[i] = embedding_table_cpu[tid] + + assert torch.equal(output.cpu(), ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_embedding_lookup_single_token(): + """Single token — exercises the minimal-work path.""" + dtype = torch.float16 + input_ids, embedding_table = _make_lookup_inputs( + num_tokens=1, vocab_size=100, hidden_size=64, dtype=dtype, seed=2026, + ) + + output = embedding_lookup(input_ids, embedding_table) + torch.cuda.synchronize() + + expected = embedding_table[input_ids[0].item()].unsqueeze(0) + assert torch.equal(output, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_embedding_lookup_all_same_token(): + """All tokens identical — verifies broadcast-like correctness.""" + num_tokens = 64 + vocab_size = 1000 + hidden_size = 128 + dtype = torch.float16 + + torch.manual_seed(2026) + device = "cuda" + embedding_table = torch.randn( + vocab_size, hidden_size, device=device, dtype=dtype + ) + input_ids = torch.full( + (num_tokens,), 42, device=device, dtype=torch.int32 + ) + + output = embedding_lookup(input_ids, embedding_table) + torch.cuda.synchronize() + + expected = embedding_table[42].unsqueeze(0).expand(num_tokens, -1) + assert torch.equal(output, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_embedding_lookup_large_hidden_size(): + """Large hidden_size (4096) — verifies vectorised copy handles long rows.""" + input_ids, embedding_table = _make_lookup_inputs( + num_tokens=32, vocab_size=1000, hidden_size=4096, + 
dtype=torch.float16, seed=2026,
+    )
+
+    output = embedding_lookup(input_ids, embedding_table)
+    torch.cuda.synchronize()
+
+    expected = torch.index_select(embedding_table, 0, input_ids)
+    assert torch.equal(output, expected)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
+def test_embedding_lookup_large_batch():
+    """Large batch (2048 tokens) — stress test for grid dimension."""
+    input_ids, embedding_table = _make_lookup_inputs(
+        num_tokens=2048, vocab_size=32000, hidden_size=256,
+        dtype=torch.float16, seed=2026,
+    )
+
+    output = embedding_lookup(input_ids, embedding_table)
+    torch.cuda.synchronize()
+
+    expected = torch.index_select(embedding_table, 0, input_ids)
+    assert torch.equal(output, expected)

From 41d4f443bb1f5122f381c7613f2f43381610541c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=BB=84=E6=8C=AF=E5=8D=8E?= <1823301126@qq.com>
Date: Tue, 3 Mar 2026 15:04:00 +0800
Subject: [PATCH 2/2] feat(kernel): add vocab embedding CUDA kernels with benchmarks

Benchmark results (num_tokens=1024):
- embedding_lookup: 4.03x speedup
- embedding_lookup_with_image: 7.99x speedup
- assemble_deepstack_embedding: 8.74x speedup
- embedding_lookup_multimodal: 9.89x speedup
---
 .../benchmarks/bench_vocab_embedding.py | 569 ++++++++++++++++++
 1 file changed, 569 insertions(+)
 create mode 100644 mllm-kernel/benchmarks/bench_vocab_embedding.py

diff --git a/mllm-kernel/benchmarks/bench_vocab_embedding.py b/mllm-kernel/benchmarks/bench_vocab_embedding.py
new file mode 100644
index 00000000..a6f8cc91
--- /dev/null
+++ b/mllm-kernel/benchmarks/bench_vocab_embedding.py
@@ -0,0 +1,569 @@
+"""Benchmark vocab_embedding ops vs torch baseline with torch.profiler.
+ +Example: + python benchmarks/bench_vocab_embedding.py --op all --warmup 20 --iters 200 + python benchmarks/bench_vocab_embedding.py --op embedding_lookup --num-tokens 1024 +""" + +from __future__ import annotations + +import argparse + +import torch +from torch.profiler import ProfilerActivity, profile + +from mllm_kernel.cuda.jit.vocab_embedding import ( + assemble_deepstack_embedding, + embedding_lookup, + embedding_lookup_multimodal, + embedding_lookup_with_image, +) + +ALL_OPS = [ + "embedding_lookup", + "embedding_lookup_with_image", + "assemble_deepstack_embedding", + "embedding_lookup_multimodal", +] + + +def _run_embedding_lookup_once(input_ids, embedding_table): + embedding_lookup(input_ids, embedding_table) + + +def _run_torch_embedding_lookup_once(input_ids, embedding_table, output): + output[:] = embedding_table[input_ids.long()] + + +def _run_embedding_lookup_with_image_once(input_ids, embedding_table, image_embeds): + embedding_lookup_with_image(input_ids, embedding_table, image_embeds) + + +def _run_torch_embedding_lookup_with_image_once( + input_ids, embedding_table, image_embeds, output +): + vocab_size = embedding_table.shape[0] + ids_long = input_ids.long() + text_mask = input_ids < vocab_size + img_mask = ~text_mask + output[text_mask] = embedding_table[ids_long[text_mask]] + output[img_mask] = image_embeds[ids_long[img_mask] - vocab_size] + + +def _run_assemble_deepstack_once(input_ids, deepstack_features, vocab_size): + assemble_deepstack_embedding(input_ids, deepstack_features, vocab_size) + + +def _run_torch_assemble_deepstack_once( + input_ids, deepstack_features, vocab_size, output +): + output.zero_() + img_mask = input_ids >= vocab_size + output[img_mask] = deepstack_features[input_ids[img_mask].long() - vocab_size] + + +def _run_multimodal_once( + input_ids, + embedding_table, + multimodal_indices, + image_embeds, + audio_embeds, + image_token_id, + audio_token_id, +): + embedding_lookup_multimodal( + input_ids, + embedding_table, + 
multimodal_indices, + image_embeds, + audio_embeds, + image_token_id, + audio_token_id, + ) + + +def _run_torch_multimodal_once( + input_ids, + embedding_table, + multimodal_indices, + image_embeds, + audio_embeds, + image_token_id, + audio_token_id, + output, +): + vocab_size = embedding_table.shape[0] + ids_long = input_ids.long() + output.zero_() + text_mask = (ids_long >= 0) & (ids_long < vocab_size) + output[text_mask] = embedding_table[ids_long[text_mask]] + if image_embeds is not None and image_embeds.shape[0] > 0 and image_token_id >= 0: + img_mask = input_ids == image_token_id + output[img_mask] = image_embeds[multimodal_indices[img_mask].long()] + if audio_embeds is not None and audio_embeds.shape[0] > 0 and audio_token_id >= 0: + aud_mask = input_ids == audio_token_id + output[aud_mask] = audio_embeds[multimodal_indices[aud_mask].long()] + + +def _make_lookup_inputs( + *, num_tokens, vocab_size, hidden_size, dtype, device, seed +): + torch.manual_seed(seed) + input_ids = torch.randint( + 0, vocab_size, (num_tokens,), device=device, dtype=torch.int32 + ) + embedding_table = torch.randn(vocab_size, hidden_size, device=device, dtype=dtype) + return input_ids, embedding_table + + +def _make_lookup_with_image_inputs( + *, num_tokens, vocab_size, hidden_size, image_token_len, dtype, device, seed +): + torch.manual_seed(seed) + n_text = num_tokens // 2 + n_image = num_tokens - n_text + text_ids = torch.randint( + 0, vocab_size, (n_text,), device=device, dtype=torch.int32 + ) + image_ids = torch.randint( + vocab_size, + vocab_size + image_token_len, + (n_image,), + device=device, + dtype=torch.int32, + ) + input_ids = torch.cat([text_ids, image_ids]) + embedding_table = torch.randn(vocab_size, hidden_size, device=device, dtype=dtype) + image_embeds = torch.randn( + image_token_len, hidden_size, device=device, dtype=dtype + ) + return input_ids, embedding_table, image_embeds + + +def _make_deepstack_inputs( + *, num_tokens, vocab_size, hidden_size, 
num_image_tokens, dtype, device, seed +): + torch.manual_seed(seed) + input_ids = torch.randint( + vocab_size, + vocab_size + num_image_tokens, + (num_tokens,), + device=device, + dtype=torch.int32, + ) + deepstack_features = torch.randn( + num_image_tokens, hidden_size, device=device, dtype=dtype + ) + return input_ids, deepstack_features + + +def _make_multimodal_inputs( + *, + num_tokens, + vocab_size, + hidden_size, + image_token_len, + audio_token_len, + image_token_id, + audio_token_id, + dtype, + device, + seed, +): + torch.manual_seed(seed) + + n_text = num_tokens // 3 + n_image = num_tokens // 3 + n_audio = num_tokens - n_text - n_image + + text_ids = torch.randint( + 0, vocab_size, (n_text,), device=device, dtype=torch.int32 + ) + image_ids = torch.full( + (n_image,), image_token_id, device=device, dtype=torch.int32 + ) + audio_ids = torch.full( + (n_audio,), audio_token_id, device=device, dtype=torch.int32 + ) + input_ids = torch.cat([text_ids, image_ids, audio_ids]) + + text_idx = torch.zeros(n_text, device=device, dtype=torch.int32) + image_idx = torch.randint( + 0, image_token_len, (n_image,), device=device, dtype=torch.int32 + ) + audio_idx = torch.randint( + 0, audio_token_len, (n_audio,), device=device, dtype=torch.int32 + ) + multimodal_indices = torch.cat([text_idx, image_idx, audio_idx]) + + perm = torch.randperm(num_tokens, device=device) + input_ids = input_ids[perm] + multimodal_indices = multimodal_indices[perm] + + embedding_table = torch.randn(vocab_size, hidden_size, device=device, dtype=dtype) + image_embeds = torch.randn( + image_token_len, hidden_size, device=device, dtype=dtype + ) + audio_embeds = torch.randn( + audio_token_len, hidden_size, device=device, dtype=dtype + ) + return input_ids, embedding_table, multimodal_indices, image_embeds, audio_embeds + + +def _profile_path( + name: str, fn, *, warmup: int, iters: int, row_limit: int, trace_path: str | None +): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + with 
profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=False, + profile_memory=False, + with_stack=False, + ) as prof: + for _ in range(iters): + fn() + torch.cuda.synchronize() + + events = prof.key_averages() + # torch profiler times are in microseconds. + # PyTorch versions vary between *cuda* and *device* naming. + time_attr = ( + "self_cuda_time_total" + if events and hasattr(events[0], "self_cuda_time_total") + else "self_device_time_total" + ) + sort_key = ( + "self_cuda_time_total" + if time_attr == "self_cuda_time_total" + else "self_device_time_total" + ) + total_self_device_us = sum(float(getattr(evt, time_attr, 0.0)) for evt in events) + avg_self_device_us = total_self_device_us / max(iters, 1) + + print(f"\n=== {name} ===") + print( + prof.key_averages().table( + sort_by=sort_key, + row_limit=row_limit, + ) + ) + print(f"{name} total self device time: {total_self_device_us:.2f} us") + print(f"{name} avg self device time/iter: {avg_self_device_us:.2f} us") + + if trace_path: + prof.export_chrome_trace(trace_path) + print(f"{name} trace exported: {trace_path}") + + return avg_self_device_us + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark vocab_embedding ops vs torch baseline using torch.profiler" + ) + parser.add_argument( + "--op", + type=str, + default="all", + choices=["all"] + ALL_OPS, + ) + parser.add_argument("--num-tokens", type=int, default=1024) + parser.add_argument("--vocab-size", type=int, default=32000) + parser.add_argument("--hidden-size", type=int, default=4096) + parser.add_argument("--image-token-len", type=int, default=576) + parser.add_argument("--audio-token-len", type=int, default=256) + parser.add_argument("--image-token-id", type=int, default=32000) + parser.add_argument("--audio-token-id", type=int, default=32001) + parser.add_argument( + "--dtype", + type=str, + default="float16", + choices=["float16", "bfloat16"], + ) + parser.add_argument("--warmup", type=int, 
default=50) + parser.add_argument("--iters", type=int, default=200) + parser.add_argument("--row-limit", type=int, default=20) + parser.add_argument("--export-trace-dir", type=str, default="") + parser.add_argument("--seed", type=int, default=0) + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this benchmark") + + torch.manual_seed(args.seed) + device = torch.device("cuda") + dtype = getattr(torch, args.dtype) + trace_dir = args.export_trace_dir.strip() + + ops = ALL_OPS if args.op == "all" else [args.op] + + for op in ops: + if op == "embedding_lookup": + input_ids, embedding_table = _make_lookup_inputs( + num_tokens=args.num_tokens, + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + dtype=dtype, + device=device, + seed=args.seed, + ) + output = torch.empty( + args.num_tokens, args.hidden_size, dtype=dtype, device=device + ) + + print("=== embedding_lookup profiler benchmark ===") + print( + f"shape: num_tokens={args.num_tokens}, vocab_size={args.vocab_size}, " + f"hidden_size={args.hidden_size}, dtype={dtype}" + ) + print( + f"warmup={args.warmup}, iters={args.iters}, row_limit={args.row_limit}" + ) + + kernel_trace = ( + f"{trace_dir}/embedding_lookup_kernel_trace.json" + if trace_dir + else None + ) + torch_trace = ( + f"{trace_dir}/embedding_lookup_torch_trace.json" + if trace_dir + else None + ) + + kernel_avg_us = _profile_path( + "embedding_lookup", + lambda: _run_embedding_lookup_once(input_ids, embedding_table), + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=kernel_trace, + ) + + torch_avg_us = _profile_path( + "torch_embedding_lookup", + lambda: _run_torch_embedding_lookup_once( + input_ids, embedding_table, output + ), + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=torch_trace, + ) + + speedup = torch_avg_us / max(kernel_avg_us, 1e-12) + print(f"\nSpeedup: {speedup:.3f}x") + + elif op == 
"embedding_lookup_with_image": + input_ids, embedding_table, image_embeds = _make_lookup_with_image_inputs( + num_tokens=args.num_tokens, + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + image_token_len=args.image_token_len, + dtype=dtype, + device=device, + seed=args.seed, + ) + output = torch.empty( + args.num_tokens, args.hidden_size, dtype=dtype, device=device + ) + + print("=== embedding_lookup_with_image profiler benchmark ===") + print( + f"shape: num_tokens={args.num_tokens}, vocab_size={args.vocab_size}, " + f"hidden_size={args.hidden_size}, image_token_len={args.image_token_len}, " + f"dtype={dtype}" + ) + print( + f"warmup={args.warmup}, iters={args.iters}, row_limit={args.row_limit}" + ) + + kernel_trace = ( + f"{trace_dir}/embedding_lookup_with_image_kernel_trace.json" + if trace_dir + else None + ) + torch_trace = ( + f"{trace_dir}/embedding_lookup_with_image_torch_trace.json" + if trace_dir + else None + ) + + kernel_avg_us = _profile_path( + "embedding_lookup_with_image", + lambda: _run_embedding_lookup_with_image_once( + input_ids, embedding_table, image_embeds + ), + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=kernel_trace, + ) + + torch_avg_us = _profile_path( + "torch_embedding_lookup_with_image", + lambda: _run_torch_embedding_lookup_with_image_once( + input_ids, embedding_table, image_embeds, output + ), + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=torch_trace, + ) + + speedup = torch_avg_us / max(kernel_avg_us, 1e-12) + print(f"\nSpeedup: {speedup:.3f}x") + + elif op == "assemble_deepstack_embedding": + input_ids, deepstack_features = _make_deepstack_inputs( + num_tokens=args.num_tokens, + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + num_image_tokens=args.image_token_len, + dtype=dtype, + device=device, + seed=args.seed, + ) + output = torch.empty( + args.num_tokens, args.hidden_size, dtype=dtype, device=device + ) + + print("=== 
assemble_deepstack_embedding profiler benchmark ===") + print( + f"shape: num_tokens={args.num_tokens}, vocab_size={args.vocab_size}, " + f"hidden_size={args.hidden_size}, num_image_tokens={args.image_token_len}, " + f"dtype={dtype}" + ) + print( + f"warmup={args.warmup}, iters={args.iters}, row_limit={args.row_limit}" + ) + + kernel_trace = ( + f"{trace_dir}/assemble_deepstack_kernel_trace.json" + if trace_dir + else None + ) + torch_trace = ( + f"{trace_dir}/assemble_deepstack_torch_trace.json" + if trace_dir + else None + ) + + vocab_size = args.vocab_size + + kernel_avg_us = _profile_path( + "assemble_deepstack_embedding", + lambda: _run_assemble_deepstack_once( + input_ids, deepstack_features, vocab_size + ), + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=kernel_trace, + ) + + torch_avg_us = _profile_path( + "torch_assemble_deepstack", + lambda: _run_torch_assemble_deepstack_once( + input_ids, deepstack_features, vocab_size, output + ), + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=torch_trace, + ) + + speedup = torch_avg_us / max(kernel_avg_us, 1e-12) + print(f"\nSpeedup: {speedup:.3f}x") + + elif op == "embedding_lookup_multimodal": + ( + input_ids, + embedding_table, + multimodal_indices, + image_embeds, + audio_embeds, + ) = _make_multimodal_inputs( + num_tokens=args.num_tokens, + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + image_token_len=args.image_token_len, + audio_token_len=args.audio_token_len, + image_token_id=args.image_token_id, + audio_token_id=args.audio_token_id, + dtype=dtype, + device=device, + seed=args.seed, + ) + output = torch.empty( + args.num_tokens, args.hidden_size, dtype=dtype, device=device + ) + + print("=== embedding_lookup_multimodal profiler benchmark ===") + print( + f"shape: num_tokens={args.num_tokens}, vocab_size={args.vocab_size}, " + f"hidden_size={args.hidden_size}, image_token_len={args.image_token_len}, " + 
f"audio_token_len={args.audio_token_len}, dtype={dtype}" + ) + print( + f"image_token_id={args.image_token_id}, " + f"audio_token_id={args.audio_token_id}" + ) + print( + f"warmup={args.warmup}, iters={args.iters}, row_limit={args.row_limit}" + ) + + image_token_id = args.image_token_id + audio_token_id = args.audio_token_id + + kernel_trace = ( + f"{trace_dir}/multimodal_kernel_trace.json" if trace_dir else None + ) + torch_trace = ( + f"{trace_dir}/multimodal_torch_trace.json" if trace_dir else None + ) + + kernel_avg_us = _profile_path( + "embedding_lookup_multimodal", + lambda: _run_multimodal_once( + input_ids, + embedding_table, + multimodal_indices, + image_embeds, + audio_embeds, + image_token_id, + audio_token_id, + ), + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=kernel_trace, + ) + + torch_avg_us = _profile_path( + "torch_multimodal", + lambda: _run_torch_multimodal_once( + input_ids, + embedding_table, + multimodal_indices, + image_embeds, + audio_embeds, + image_token_id, + audio_token_id, + output, + ), + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=torch_trace, + ) + + speedup = torch_avg_us / max(kernel_avg_us, 1e-12) + print(f"\nSpeedup: {speedup:.3f}x") + + +if __name__ == "__main__": + main()