
feat(kernel): add vocab embedding CUDA kernels #645

Open

huangzhenhua111 wants to merge 2 commits into UbiquitousLearning:main from huangzhenhua111:feat/vocab-embedding

Conversation


@huangzhenhua111 commented Mar 2, 2026

Add four vectorized CUDA embedding kernels:

  • embedding_lookup: standard token embedding (4.03x speedup @ 1024 tokens)
  • embedding_lookup_with_image: token + image embedding fusion (test input: 1/2 text + 1/2 image, in order) (7.99x speedup @ 1024 tokens)
  • assemble_deepstack_embedding: extract image-only embeddings (8.74x speedup @ 1024 tokens)
  • embedding_lookup_multimodal: text + image + audio embedding (test input: 1/3 text + 1/3 image + 1/3 audio, randomly interleaved) (9.89x speedup @ 1024 tokens)

All 17 tests passed.

Average speedup: 7.66x (1024 tokens), 5.79x (4096), 5.99x (8192).
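As a quick arithmetic check, the 7.66x figure at 1024 tokens is the plain mean of the four per-kernel speedups listed above:

```python
# The 1024-token average is the arithmetic mean of the four per-kernel speedups.
speedups_1024 = [4.03, 7.99, 8.74, 9.89]
average = sum(speedups_1024) / len(speedups_1024)
print(round(average, 2))  # 7.66
```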


coderabbitai bot commented Mar 2, 2026

📝 Walkthrough

Introduces a complete CUDA-based vocabulary embedding module with four kernel variants for standard lookup, image-fused lookup, deepstack assembly, and multimodal (text/image/audio) embedding operations, including Python wrappers and comprehensive test coverage.

Changes

CUDA Kernel Implementation (mllm_kernel/cuda/csrc/vocab_embedding.cuh)
Implements four CUDA kernels with parameter structs: embedding_lookup_kernel, embedding_lookup_with_image_kernel, assemble_deepstack_embedding_kernel, embedding_lookup_multimodal_kernel. Includes device helpers for warp-level row operations and host-side launch wrappers with input validation, dtype/device checks, and stride computation.

Python JIT Wrapper (mllm_kernel/cuda/jit/vocab_embedding.py)
Provides four public Python functions (embedding_lookup, embedding_lookup_with_image, assemble_deepstack_embedding, embedding_lookup_multimodal) that wrap the CUDA kernels. Handles tensor contiguity, dtype normalization, output allocation, and kernel invocation via the TVM FFI interface.

Module Exports (mllm_kernel/cuda/jit/__init__.py)
Expands the public API to export the four new embedding functions from the vocab_embedding module while retaining existing exports.

Test Suite (tests/test_vocab_embedding.py)
Comprehensive test coverage for all four embedding operations across multiple dtypes (float16, bfloat16), out-of-range handling, image/audio fusion scenarios, and edge cases, with CUDA availability gating.
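The kernel and test sources are not shown in this thread, but the out-of-range handling described in the walkthrough has simple reference semantics. A pure-Python sketch (function and variable names are illustrative, not taken from the PR) of what the tests presumably compare the embedding_lookup kernel against:

```python
def embedding_lookup_ref(input_ids, embedding_table):
    """Gather one embedding row per token id; zero-fill out-of-range ids."""
    vocab_size = len(embedding_table)
    hidden_size = len(embedding_table[0])
    out = []
    for tok in input_ids:
        if 0 <= tok < vocab_size:
            out.append(list(embedding_table[tok]))
        else:
            # Matches the zero-fill behavior the walkthrough describes
            # for out-of-range tokens.
            out.append([0.0] * hidden_size)
    return out

# Tiny table: vocab_size=4, hidden_size=3, row r holds [3r, 3r+1, 3r+2].
table = [[float(3 * r + c) for c in range(3)] for r in range(4)]
result = embedding_lookup_ref([1, 3, 7], table)  # 7 is out of range
print(result[2])  # [0.0, 0.0, 0.0]
```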

Sequence Diagram(s)

sequenceDiagram
    participant User as Python Caller
    participant Wrapper as JIT Wrapper<br/>(vocab_embedding.py)
    participant TVM as TVM FFI<br/>Interface
    participant CUDA as CUDA Kernel<br/>(Device)
    
    User->>Wrapper: embedding_lookup(input_ids,<br/>embedding_table)
    activate Wrapper
    Wrapper->>Wrapper: Validate inputs (dtype, device,<br/>contiguity)
    Wrapper->>Wrapper: Allocate output tensor
    Wrapper->>Wrapper: Compute stride_bytes from<br/>hidden_size
    Wrapper->>TVM: Invoke _embedding_lookup_kernel<br/>(via LaunchKernel)
    activate TVM
    TVM->>CUDA: Launch embedding_lookup_kernel<br/>with grid/block config
    activate CUDA
    CUDA->>CUDA: Each thread looks up token<br/>embedding from table
    CUDA->>CUDA: Fill zeros for<br/>out-of-range tokens
    CUDA-->>TVM: Kernel complete
    deactivate CUDA
    TVM-->>Wrapper: Return
    deactivate TVM
    Wrapper-->>User: Return output tensor
    deactivate Wrapper
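The "Compute stride_bytes" step in the diagram is straightforward for a contiguous row-major table: one row is hidden_size elements of itemsize bytes. A minimal sketch (the dtype table and helper name are illustrative, not the PR's actual code):

```python
# Hypothetical stride helper, assuming a contiguous row-major embedding table.
DTYPE_ITEMSIZE = {"float16": 2, "bfloat16": 2, "float32": 4}

def row_stride_bytes(hidden_size: int, dtype: str) -> int:
    """Bytes between the start of consecutive embedding rows."""
    return hidden_size * DTYPE_ITEMSIZE[dtype]

print(row_stride_bytes(4096, "float16"))  # 8192
```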

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Poem

🐰 Four kernels bloom in CUDA's garden bright,
Text, images, and audio take flight,
From tokens deep to embeddings refined,
A multimodal feast for any mind!
Hop along, dear reviewer, all checks align! 🌟

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Description check ⚠️ Warning
The pull request description is minimal and lacks the required sections from the template. Resolution: add structured sections covering motivation, implementation details, and testing information, following the repository contribution guidelines template.

✅ Passed checks (2 passed)

Title check ✅ Passed
The title accurately summarizes the main change: adding CUDA kernels for vocabulary embedding operations.

Docstring Coverage ✅ Passed
Docstring coverage is 84.00%, which meets the required threshold of 80.00%.


Benchmark results (num_tokens=1024):
- embedding_lookup: 4.03x speedup
- embedding_lookup_with_image: 7.99x speedup
- assemble_deepstack_embedding: 8.74x speedup
- embedding_lookup_multimodal: 9.89x speedup
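These figures are presumably the ratio of the torch baseline's average iteration time to the kernel's (both measured with torch.profiler); the timing numbers below are illustrative only:

```python
def speedup(baseline_avg_us: float, kernel_avg_us: float) -> float:
    """Speedup = baseline time / kernel time (higher is better)."""
    return baseline_avg_us / kernel_avg_us

# Illustrative numbers: a 40.3 us torch baseline against a 10.0 us kernel
# reproduces the 4.03x figure reported for embedding_lookup.
print(round(speedup(40.3, 10.0), 2))  # 4.03
```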
@coderabbitai bot left a comment

🧹 Nitpick comments (3)
mllm-kernel/benchmarks/bench_vocab_embedding.py (3)

269-272: Use unpacking for cleaner list construction.

Static analysis suggests using list unpacking instead of concatenation.

🔧 Proposed fix
         "--op",
         type=str,
         default="all",
-        choices=["all"] + ALL_OPS,
+        choices=["all", *ALL_OPS],
     )
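Both forms build the same list; the unpacked form just skips the intermediate list allocation:

```python
ALL_OPS = [
    "embedding_lookup",
    "embedding_lookup_with_image",
    "assemble_deepstack_embedding",
    "embedding_lookup_multimodal",
]

# Concatenation allocates a temporary list; unpacking builds the result in one pass.
concatenated = ["all"] + ALL_OPS
unpacked = ["all", *ALL_OPS]
print(concatenated == unpacked)  # True
```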

1-6: Add shebang for executable script.

Static analysis indicates the file is executable but lacks a shebang line.

🔧 Proposed fix
+#!/usr/bin/env python3
 """Benchmark vocab_embedding ops vs torch baseline with torch.profiler.

337-356: Lambda closures capture loop variables by reference.

Static analysis (B023) flags that lambdas capture input_ids, embedding_table, output, etc. by reference rather than by value. While this works correctly in the current code because each lambda is consumed immediately within _profile_path before the loop continues, it's a pattern that can cause subtle bugs if the code is refactored (e.g., if lambdas were stored and executed later).

Consider using default arguments to capture values explicitly. This pattern repeats throughout the benchmark loop (lines 398, 409, 460, 471, 532, 549).

🔧 Proposed fix (example for embedding_lookup)
             kernel_avg_us = _profile_path(
                 "embedding_lookup",
-                lambda: _run_embedding_lookup_once(input_ids, embedding_table),
+                lambda ids=input_ids, table=embedding_table: _run_embedding_lookup_once(ids, table),
                 warmup=args.warmup,
                 iters=args.iters,
                 row_limit=args.row_limit,
                 trace_path=kernel_trace,
             )

             torch_avg_us = _profile_path(
                 "torch_embedding_lookup",
-                lambda: _run_torch_embedding_lookup_once(
-                    input_ids, embedding_table, output
-                ),
+                lambda ids=input_ids, table=embedding_table, out=output: _run_torch_embedding_lookup_once(
+                    ids, table, out
+                ),
                 warmup=args.warmup,
                 iters=args.iters,
                 row_limit=args.row_limit,
                 trace_path=torch_trace,
             )

Alternatively, use functools.partial for cleaner syntax:

from functools import partial
# ...
lambda: _run_embedding_lookup_once(input_ids, embedding_table)
# becomes:
partial(_run_embedding_lookup_once, input_ids, embedding_table)
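The hazard B023 guards against is easy to demonstrate in isolation: store the closures and call them after the loop finishes, and by-reference capture returns the loop variable's final value, while default-argument binding (or functools.partial) freezes it at creation time:

```python
from functools import partial

# By-reference capture: each lambda reads i when called, after the loop,
# so all three see the final value.
late = [lambda: i for i in range(3)]
print([f() for f in late])  # [2, 2, 2]

# Default-argument binding evaluates i once, at lambda creation time.
bound = [lambda i=i: i for i in range(3)]
print([f() for f in bound])  # [0, 1, 2]

# functools.partial also binds the argument value at creation time.
partials = [partial(lambda x: x, i) for i in range(3)]
print([f() for f in partials])  # [0, 1, 2]
```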

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 945af99 and 41d4f44.

📒 Files selected for processing (1)
  • mllm-kernel/benchmarks/bench_vocab_embedding.py
