feat(kernel): add vocab embedding CUDA kernels #645
huangzhenhua111 wants to merge 2 commits into UbiquitousLearning:main from
Conversation
Add four vectorized CUDA embedding kernels:
- embedding_lookup: standard token embedding
- embedding_lookup_with_image: token + image embedding fusion
- assemble_deepstack_embedding: extract image-only embeddings
- embedding_lookup_multimodal: text + image + audio embedding

All 17 tests passed.
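As a rough reference for the first kernel's semantics, here is a pure-Python sketch of embedding_lookup (zero rows for out-of-range token ids, matching the kernel behavior described in the review walkthrough). This is an illustrative model, not the actual CUDA implementation:

```python
def embedding_lookup_ref(input_ids, embedding_table):
    """Reference semantics: output[i] = table[id] for valid ids, a zero row otherwise."""
    vocab_size = len(embedding_table)
    hidden_size = len(embedding_table[0])
    out = []
    for tok in input_ids:
        if 0 <= tok < vocab_size:
            out.append(list(embedding_table[tok]))
        else:
            # Mirrors the kernel's zero-fill path for out-of-range tokens
            out.append([0.0] * hidden_size)
    return out

table = [[1.0, 2.0], [3.0, 4.0]]
print(embedding_lookup_ref([1, 0, 5], table))  # → [[3.0, 4.0], [1.0, 2.0], [0.0, 0.0]]
```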
📝 Walkthrough
Introduces a complete CUDA-based vocabulary embedding module with four kernel variants for standard lookup, image-fused lookup, deepstack assembly, and multimodal (text/image/audio) embedding operations, including Python wrappers and comprehensive test coverage.
Sequence Diagram(s)
```mermaid
sequenceDiagram
    participant User as Python Caller
    participant Wrapper as JIT Wrapper<br/>(vocab_embedding.py)
    participant TVM as TVM FFI<br/>Interface
    participant CUDA as CUDA Kernel<br/>(Device)
    User->>Wrapper: embedding_lookup(input_ids,<br/>embedding_table)
    activate Wrapper
    Wrapper->>Wrapper: Validate inputs (dtype, device,<br/>contiguity)
    Wrapper->>Wrapper: Allocate output tensor
    Wrapper->>Wrapper: Compute stride_bytes from<br/>hidden_size
    Wrapper->>TVM: Invoke _embedding_lookup_kernel<br/>(via LaunchKernel)
    activate TVM
    TVM->>CUDA: Launch embedding_lookup_kernel<br/>with grid/block config
    activate CUDA
    CUDA->>CUDA: Each thread looks up token<br/>embedding from table
    CUDA->>CUDA: Fill zeros for<br/>out-of-range tokens
    CUDA-->>TVM: Kernel complete
    deactivate CUDA
    TVM-->>Wrapper: Return
    deactivate TVM
    Wrapper-->>User: Return output tensor
    deactivate Wrapper
```
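The wrapper-side steps in the diagram (validate inputs, allocate output, compute the row stride in bytes) can be sketched in plain Python. The helper names and checks below are illustrative stand-ins, not the actual vocab_embedding.py API:

```python
def compute_stride_bytes(hidden_size: int, dtype_itemsize: int) -> int:
    """Row stride in bytes for a contiguous [num_tokens, hidden_size] output."""
    return hidden_size * dtype_itemsize

def validate_lookup_inputs(ids_is_int: bool, table_is_contiguous: bool, same_device: bool):
    """Hypothetical mirror of the wrapper's precondition checks before kernel launch."""
    if not ids_is_int:
        raise TypeError("input_ids must be an integer tensor")
    if not table_is_contiguous:
        raise ValueError("embedding_table must be contiguous")
    if not same_device:
        raise ValueError("input_ids and embedding_table must share a device")

# e.g. fp16 rows with hidden_size=4096: 4096 * 2 bytes per element
print(compute_stride_bytes(4096, 2))  # → 8192
```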
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Benchmark results (num_tokens=1024):
- embedding_lookup: 4.03x speedup
- embedding_lookup_with_image: 7.99x speedup
- assemble_deepstack_embedding: 8.74x speedup
- embedding_lookup_multimodal: 9.89x speedup
🧹 Nitpick comments (3)
mllm-kernel/benchmarks/bench_vocab_embedding.py (3)
269-272: Use unpacking for cleaner list construction. Static analysis suggests using list unpacking instead of concatenation.
🔧 Proposed fix
```diff
     "--op", type=str, default="all",
-    choices=["all"] + ALL_OPS,
+    choices=["all", *ALL_OPS],
 )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mllm-kernel/benchmarks/bench_vocab_embedding.py` around lines 269 - 272, The choices list for the argument uses concatenation ["all"] + ALL_OPS; replace that with list unpacking to be cleaner and idiomatic by using ["all", *ALL_OPS] where the argument is defined (the choices= parameter in the add_argument call that references ALL_OPS).
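For reference, the two spellings build the same list; ALL_OPS below is a placeholder value, not the benchmark's actual list:

```python
ALL_OPS = ["embedding_lookup", "embedding_lookup_with_image"]  # placeholder values

by_concat = ["all"] + ALL_OPS       # the current spelling
by_unpack = ["all", *ALL_OPS]       # the spelling the static analysis hint prefers

print(by_concat == by_unpack)  # → True
```

Unpacking avoids constructing the intermediate `["all"]` list and reads as a single literal, which is why linters flag the concatenation form.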
1-6: Add shebang for executable script. Static analysis indicates the file is executable but lacks a shebang line.
🔧 Proposed fix
```diff
+#!/usr/bin/env python3
 """Benchmark vocab_embedding ops vs torch baseline with torch.profiler.
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mllm-kernel/benchmarks/bench_vocab_embedding.py` around lines 1 - 6, This script file (bench_vocab_embedding.py) is missing a shebang; make the file executable by adding a Unix shebang line (e.g., for python3) at the very top of the file before the module docstring so the script can be run directly from the shell; ensure the shebang is the first line and preserved when committing.
337-356: Lambda closures capture loop variables by reference. Static analysis (B023) flags that lambdas capture `input_ids`, `embedding_table`, `output`, etc. by reference rather than by value. While this works correctly in the current code because each lambda is consumed immediately within `_profile_path` before the loop continues, it's a pattern that can cause subtle bugs if the code is refactored (e.g., if lambdas were stored and executed later). Consider using default arguments to capture values explicitly. This pattern repeats throughout the benchmark loop (lines 398, 409, 460, 471, 532, 549).
🔧 Proposed fix (example for embedding_lookup)
```diff
 kernel_avg_us = _profile_path(
     "embedding_lookup",
-    lambda: _run_embedding_lookup_once(input_ids, embedding_table),
+    lambda ids=input_ids, table=embedding_table: _run_embedding_lookup_once(ids, table),
     warmup=args.warmup,
     iters=args.iters,
     row_limit=args.row_limit,
     trace_path=kernel_trace,
 )
 torch_avg_us = _profile_path(
     "torch_embedding_lookup",
-    lambda: _run_torch_embedding_lookup_once(
-        input_ids, embedding_table, output
-    ),
+    lambda ids=input_ids, table=embedding_table, out=output: _run_torch_embedding_lookup_once(
+        ids, table, out
+    ),
     warmup=args.warmup,
     iters=args.iters,
     row_limit=args.row_limit,
     trace_path=torch_trace,
 )
```
Alternatively, use `functools.partial` for cleaner syntax:
```python
from functools import partial
# ...
lambda: _run_embedding_lookup_once(input_ids, embedding_table)
# becomes:
partial(_run_embedding_lookup_once, input_ids, embedding_table)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mllm-kernel/benchmarks/bench_vocab_embedding.py` around lines 337 - 356, The lambdas passed to _profile_path capture loop variables (e.g., input_ids, embedding_table, output) by reference, which risks subtle bugs if refactored; update the call sites that pass lambda: _run_embedding_lookup_once(input_ids, embedding_table) and lambda: _run_torch_embedding_lookup_once(input_ids, embedding_table, output) to capture values explicitly (either by using default args in the lambda, e.g. lambda input_ids=input_ids, embedding_table=embedding_table: _run_embedding_lookup_once(input_ids, embedding_table), or by using functools.partial to bind arguments: partial(_run_embedding_lookup_once, input_ids, embedding_table)); apply the same fix to the other similar lambda usages in the benchmark loop that call _run_* helper functions so each closure binds values at creation time.
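A minimal standalone demonstration of the late-binding hazard that B023 targets, alongside the two binding fixes suggested above:

```python
from functools import partial

# Late binding: every lambda reads i when called, after the loop has finished.
late = [lambda: i for i in range(3)]
print([f() for f in late])  # → [2, 2, 2]

# Default-argument binding captures the value at lambda creation time.
bound = [lambda i=i: i for i in range(3)]
print([f() for f in bound])  # → [0, 1, 2]

# functools.partial also binds arguments eagerly.
parts = [partial(lambda x: x, i) for i in range(3)]
print([f() for f in parts])  # → [0, 1, 2]
```

The benchmark's current lambdas are safe only because each one is invoked before the loop variable changes; the eager-binding forms stay correct even if a refactor defers the calls.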
Average speedup: 7.66x (1024), 5.79x (4096), 5.99x (8192).
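As a sanity check, the 7.66x average for num_tokens=1024 matches the arithmetic mean of the four per-kernel speedups quoted earlier in the thread:

```python
# Per-kernel speedups at num_tokens=1024, from the benchmark comment above
speedups = {
    "embedding_lookup": 4.03,
    "embedding_lookup_with_image": 7.99,
    "assemble_deepstack_embedding": 8.74,
    "embedding_lookup_multimodal": 9.89,
}
avg = sum(speedups.values()) / len(speedups)
print(round(avg, 2))  # → 7.66
```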