diff --git a/.claude/skills/impl-jit-kernel/SKILL.md b/.claude/skills/impl-jit-kernel/SKILL.md new file mode 100644 index 000000000..39cc02b6f --- /dev/null +++ b/.claude/skills/impl-jit-kernel/SKILL.md @@ -0,0 +1,486 @@ +--- +name: impl-jit-kernel +description: Guide for implementing CUDA or CPU JIT kernels in mllm-kernel. Use when the user asks to create, add, or implement a new kernel in mllm-kernel. +--- + +# Implementing a JIT Kernel in mllm-kernel + +## Overview + +mllm-kernel uses a JIT (Just-In-Time) compilation system built on `tvm_ffi`. Kernels are written in C++20 (`.cuh` for CUDA, `.cpp` for CPU), validated at runtime via `TensorMatcher`, and exposed to Python through a `@jit` decorator. No pre-compilation is needed -- kernels compile on first call and are cached at `~/.cache/mllm_kernel/`. + +## File Layout + +For a kernel named `my_kernel`: + +``` +mllm-kernel/ + mllm_kernel/ + cuda/ + csrc/my_kernel.cuh # CUDA kernel implementation + jit/my_kernel.py # Python JIT wrapper + jit/__init__.py # Add export here + cpu/ + csrc/my_kernel.cpp # CPU kernel implementation (Highway SIMD) + include/mllm_kernel/cpu/ + my_kernel.hpp # CPU SIMD body (NO #pragma once) + jit/my_kernel.py # Python JIT wrapper + jit/__init__.py # Add export here + tests/test_my_kernel.py # Pytest correctness tests + benchmarks/bench_my_kernel.py # Profiler benchmark vs PyTorch reference +``` + +--- + +## CUDA Kernel Walkthrough + +### Step 1: Write the `.cuh` kernel + +Create `mllm_kernel/cuda/csrc/my_kernel.cuh`: + +```cpp +#pragma once + +#include // TensorMatcher, SymbolicSize, SymbolicDevice, SymbolicDType +#include // RuntimeCheck, Panic, div_ceil +#include // LaunchKernel, fp16_t, bf16_t, PDL helpers + +#include +#include + +#include + +namespace { + +// --------------------------------------------------------------------------- +// 1. 
Parameter struct (trivially copyable, passed to kernel by value) +// --------------------------------------------------------------------------- +struct MyKernelParams { + const float* __restrict__ input; + float* __restrict__ output; + int32_t num_elements; +}; + +// --------------------------------------------------------------------------- +// 2. CUDA kernel +// --------------------------------------------------------------------------- +__global__ void my_kernel(const MyKernelParams params) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= params.num_elements) return; + params.output[idx] = params.input[idx] * 2.0f; +} + +// --------------------------------------------------------------------------- +// 3. Host-side launcher (entry point for TVM FFI binding) +// --------------------------------------------------------------------------- +struct MyKernel { + static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView output) { + using namespace mllm_kernel::host; + + // --- Validate tensors --- + SymbolicSize N{"num_elements"}; + SymbolicDevice device; + + (void)TensorMatcher({N}) + .with_dtype() + .with_device(device) + .verify(input); + + (void)TensorMatcher({N}) + .with_dtype() + .with_device(device) + .verify(output); + + const int64_t n = N.unwrap(); + RuntimeCheck(n > 0, "num_elements must be positive, got ", n); + + // --- Build params --- + MyKernelParams params{ + .input = static_cast(input.data_ptr()), + .output = static_cast(output.data_ptr()), + .num_elements = static_cast(n), + }; + + // --- Launch --- + constexpr int kBlock = 256; + const int grid = static_cast(div_ceil(n, kBlock)); + LaunchKernel(grid, kBlock, device.unwrap())(my_kernel, params); + } +}; + +} // namespace +``` + +**Key rules:** + +- **Always wrap in `namespace {}`** (anonymous namespace). +- **Entry point** is a `static void run(tvm::ffi::TensorView ...)` method. +- **Validate every tensor** with `TensorMatcher` before reading `.data_ptr()`. 
+- **Never dereference device pointers on host** -- `data_ptr()` returns a GPU pointer. +- **Use `LaunchKernel`** to launch -- it handles stream resolution and error checking. + +### Step 2: Write the Python JIT wrapper + +Create `mllm_kernel/cuda/jit/my_kernel.py`: + +```python +"""JIT wrapper for my_kernel CUDA kernel.""" + +import torch +from mllm_kernel.jit_utils import jit + + +@jit( + args=[], + device="cuda", + cuda_files=["my_kernel.cuh"], + cpp_wrappers=[], + cuda_wrappers=[("my_kernel", "MyKernel::run")], + func_name="my_kernel", +) +def _kernel(compiled_module, input: torch.Tensor, output: torch.Tensor) -> None: + compiled_module.my_kernel(input, output) + + +def my_kernel(input: torch.Tensor) -> torch.Tensor: + """Double every element in *input*. + + Parameters + ---------- + input : torch.Tensor + 1-D float32 tensor on CUDA. + + Returns + ------- + torch.Tensor + Same shape and dtype as *input*. + """ + output = torch.empty_like(input) + _kernel(input, output) + return output +``` + +### Step 3: Export in `__init__.py` + +Edit `mllm_kernel/cuda/jit/__init__.py` and add: + +```python +from mllm_kernel.cuda.jit.my_kernel import my_kernel +``` + +### Step 4: Clear JIT cache after editing `.cuh` + +Any time you modify the `.cuh` file, delete the cached `.so`: + +```bash +rm -rf ~/.cache/mllm_kernel/cuda_my_kernel* +``` + +The next Python call will trigger recompilation automatically. + +--- + +## Template-Parameterized CUDA Kernels + +When the kernel takes compile-time constants (e.g. 
block size, dtype), use `make_cpp_args`: + +```python +from mllm_kernel.jit_utils import jit, make_cpp_args + +def _make_kernel(block_size: int, use_pdl: bool): + cpp_args = make_cpp_args(block_size, use_pdl) # -> "256, true" + + @jit( + args=[block_size, use_pdl], + device="cuda", + cuda_files=["my_kernel.cuh"], + cpp_wrappers=[], + cuda_wrappers=[("my_kernel", f"MyKernel<{cpp_args}>::run")], + func_name="my_kernel", + ) + def _kernel(compiled_module, input, output): + compiled_module.my_kernel(input, output) + return _kernel +``` + +`make_cpp_args` converts Python types to C++ literals: +- `int/float` -> string literal +- `bool` -> `"true"` / `"false"` +- `torch.dtype` -> C++ type (`torch.float32` -> `"fp32_t"`, `torch.float16` -> `"fp16_t"`, `torch.bfloat16` -> `"bf16_t"`, `torch.int32` -> `"int32_t"`, etc.) + +--- + +## CPU Kernel Walkthrough + +CPU kernels use **Google Highway** for portable SIMD. The key difference: the `.hpp` body is included **multiple times** by Highway's `foreach_target` dispatch, so it must NOT have `#pragma once`. + +### Step 1: Write the SIMD body (`.hpp`) + +Create `mllm_kernel/cpu/include/mllm_kernel/cpu/my_kernel.hpp`: + +```cpp +// NOTE: NO #pragma once -- this file is included multiple times by Highway. 
+ +#include + +HWY_BEFORE_NAMESPACE(); +namespace mllm_kernel::cpu { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +template +inline void my_kernel_impl(float* HWY_RESTRICT dst, + const float* HWY_RESTRICT src, + size_t count) { + const hn::ScalableTag d; + const size_t lanes = hn::Lanes(d); + const auto vc = hn::Set(d, static_cast(Constant)); + size_t i = 0; + for (; i + lanes <= count; i += lanes) { + const auto v = hn::Load(d, src + i); + hn::Store(hn::Add(v, vc), d, dst + i); + } + for (; i < count; ++i) { + dst[i] = src[i] + static_cast(Constant); + } +} + +// Named entry points for HWY_EXPORT +static HWY_NOINLINE HWY_MAYBE_UNUSED void my_kernel_1(float* d, const float* s, size_t n) { + my_kernel_impl<1>(d, s, n); +} + +} // namespace HWY_NAMESPACE +} // namespace mllm_kernel::cpu +HWY_AFTER_NAMESPACE(); +``` + +### Step 2: Write the `.cpp` source + +Create `mllm_kernel/cpu/csrc/my_kernel.cpp`: + +```cpp +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "../csrc/my_kernel.cpp" +#include + +#include + +#if HWY_ONCE +#include +#endif + +namespace mllm_kernel::cpu { +#if HWY_ONCE + +HWY_EXPORT(my_kernel_1); + +template +void my_kernel(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) { + using namespace mllm_kernel::host; + SymbolicSize N{"num_elements"}; + SymbolicDevice device_; + (void)TensorMatcher({N}) + .with_dtype() + .with_device(device_) + .verify(dst) + .verify(src); + const size_t n = N.unwrap(); + auto* dst_ptr = static_cast(dst.data_ptr()); + const auto* src_ptr = static_cast(src.data_ptr()); + HWY_DYNAMIC_DISPATCH(my_kernel_1)(dst_ptr, src_ptr, n); +} + +// Explicit instantiation +template void my_kernel<1>(tvm::ffi::TensorView, tvm::ffi::TensorView); + +#endif +} // namespace mllm_kernel::cpu +``` + +### Step 3: Write the Python JIT wrapper + +Create `mllm_kernel/cpu/jit/my_kernel.py`: + +```python +import torch +from mllm_kernel.jit_utils import jit + +@jit( + args=1, + device="cpu", + 
cpp_files=["my_kernel.cpp"], + cpp_wrappers=[("my_kernel", "mllm_kernel::cpu::my_kernel<1>")], + func_name="my_kernel", +) +def _kernel_1(compiled_module, dst, src): + compiled_module.my_kernel(dst, src) + +def my_kernel(src: torch.Tensor) -> torch.Tensor: + dst = torch.empty_like(src) + _kernel_1(dst, src) + return dst +``` + +**Key CPU differences from CUDA:** + +| Aspect | CUDA | CPU | +|--------|------|-----| +| Source file | `.cuh` in `cuda/csrc/` | `.cpp` + `.hpp` in `cpu/csrc/` and `cpu/include/` | +| Namespace | Anonymous `namespace {}` | `mllm_kernel::cpu` | +| Device check | `with_device` | `with_device` | +| Launch | `LaunchKernel(grid, block, device)(...)` | Direct function call via `HWY_DYNAMIC_DISPATCH` | +| SIMD | CUDA warps | Highway `ScalableTag` | +| Wrapper fields | `cuda_files`, `cuda_wrappers` | `cpp_files`, `cpp_wrappers` | +| Wrapper name | `"MyKernel::run"` | `"mllm_kernel::cpu::my_kernel<1>"` (fully qualified) | + +--- + +## TensorMatcher Reference + +`TensorMatcher` validates shape, dtype, device, and strides of `tvm::ffi::TensorView` arguments. 
+ +```cpp +using namespace mllm_kernel::host; + +// Symbolic dimensions -- bind on first .verify(), check consistency on subsequent calls +SymbolicSize B{"batch"}, N{"seq_len"}, D{"dim"}; +SymbolicSize Stride0{"stride0"}; +SymbolicDType dtype; +SymbolicDevice device; + +// Shape [B, N, D], contiguous, float32, on CUDA +(void)TensorMatcher({B, N, D}) + .with_dtype(dtype) + .with_device(device) + .verify(tensor_a); + +// Shape [B, N, D], same dtype and device (already bound) +(void)TensorMatcher({B, N, D}) + .with_dtype(dtype) + .with_device(device) + .verify(tensor_b); + +// Shape [B, D] with explicit strides (non-contiguous OK) +(void)TensorMatcher({B, D}) + .with_strides({Stride0, 1}) + .with_dtype() + .with_device(device) + .verify(indices); + +// Multiple acceptable dtypes +SymbolicDType flex_dtype; +(void)TensorMatcher({N}) + .with_dtype(flex_dtype) + .with_device(device) + .verify(mixed_tensor); + +// Extract bound values +int64_t batch = B.unwrap(); +int64_t dim = D.unwrap(); +DLDevice dev = device.unwrap(); +``` + +--- + +## LaunchKernel Reference + +```cpp +using namespace mllm_kernel::host; + +// Basic launch (resolves CUDA stream from DLDevice) +DLDevice dev = device.unwrap(); +LaunchKernel(grid_dim, block_dim, dev)(kernel_func, param_struct); + +// With shared memory +LaunchKernel(grid, block, dev, shared_mem_bytes)(kernel, params); + +// With PDL (Programmatic Dependent Launch, sm_90+) +LaunchKernel(grid, block, dev).enable_pdl(true)(kernel, params); +``` + +--- + +## Utility Reference (`mllm_kernel::host`) + +| Function | Description | +|----------|-------------| +| `RuntimeCheck(cond, msg...)` | Throws `PanicError` if `cond` is false | +| `Panic(msg...)` | Always throws (unreachable code) | +| `div_ceil(a, b)` | Integer ceiling division | +| `dtype_bytes(DLDataType)` | Byte size of a DLPack dtype | + +CUDA-only (`mllm_kernel::device`): + +| Symbol | Value | +|--------|-------| +| `kWarpThreads` | 32 | +| `kFullMask` | 0xffffffff | +| `fp16_t` | 
`__half` | +| `bf16_t` | `__nv_bfloat16` | + +--- + +## Testing Pattern + +Create `tests/test_my_kernel.py`: + +```python +import pytest +import torch +from mllm_kernel.cuda.jit.my_kernel import my_kernel + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("n", [1, 128, 1024, 65536]) +def test_my_kernel(n): + x = torch.randn(n, dtype=torch.float32, device="cuda") + result = my_kernel(x) + torch.cuda.synchronize() + expected = x * 2.0 + assert torch.allclose(result, expected) +``` + +Run: +```bash +pytest tests/test_my_kernel.py -v +``` + +--- + +## Benchmark Pattern + +Create `benchmarks/bench_my_kernel.py`. Use `torch.profiler.profile` with `ProfilerActivity.CPU` and `ProfilerActivity.CUDA`. Compare the JIT kernel against a naive PyTorch implementation and report speedup. + +Run: +```bash +python benchmarks/bench_my_kernel.py --num-elements 1000000 +``` + +--- + +## Checklist for a New Kernel + +- [ ] `.cuh` / `.cpp` + `.hpp` kernel source created +- [ ] `TensorMatcher` validates all tensor arguments (shape, dtype, device) +- [ ] No host-side dereference of device pointers +- [ ] Python `@jit` wrapper created with correct `cuda_wrappers` or `cpp_wrappers` +- [ ] Public API function added (allocates output, calls internal `_kernel`) +- [ ] Exported in `jit/__init__.py` +- [ ] JIT cache cleared after `.cuh` edits (`rm -rf ~/.cache/mllm_kernel/cuda_*`) +- [ ] Pytest test with `@pytest.mark.parametrize` and PyTorch reference +- [ ] Benchmark with `torch.profiler` (optional but recommended) + +--- + +## Common Pitfalls + +1. **Segfault from dereferencing device pointer on host** -- `tensor.data_ptr()` returns a GPU pointer for CUDA tensors. Never read its contents in host code. Use `TensorMatcher` for validation instead. +2. **Stale JIT cache** -- After editing `.cuh`, delete `~/.cache/mllm_kernel/cuda_*/`. The old `.so` will be reused otherwise. +3. 
**Missing `#include `** -- CPU kernels must include this inside `#if HWY_ONCE` to provide `GetChosenTarget` for the JIT-built module. +4. **`#pragma once` in Highway `.hpp`** -- Highway's `foreach_target` includes the file multiple times for different SIMD targets. `#pragma once` breaks this. +5. **Wrong wrapper name** -- CUDA uses short names (`"MyKernel::run"`); CPU uses fully qualified names (`"mllm_kernel::cpu::my_kernel<1>"`). +6. **Generator device mismatch in tests** -- `torch.randperm` needs a CUDA generator on CUDA; `torch.randint` only accepts CPU generators. Use separate generators. diff --git a/.claude/skills/update-codeowners/SKILL.md b/.claude/skills/update-codeowners/SKILL.md new file mode 100644 index 000000000..286667045 --- /dev/null +++ b/.claude/skills/update-codeowners/SKILL.md @@ -0,0 +1,44 @@ +--- +name: update-codeowners +description: Updates CODEOWNERS entries safely with consistent path and owner formatting. Use when the user asks to add, remove, or modify CODEOWNERS rules, ownership mappings, reviewers, or module maintainers. +--- + +# Update CODEOWNERS + +## Goal +Maintain `CODEOWNERS` accurately while preserving the repository's existing section/comment style. + +## Workflow +1. Read the current `CODEOWNERS` file before editing. +2. Identify requested changes as one of: + - Add new path rule + - Modify owners for existing path rule + - Remove obsolete path rule + - Reorganize section comments (only if requested) +3. Update rules in place instead of creating duplicates for the same path. +4. Keep existing section headers and comment style unless the user asks to refactor structure. +5. Return a concise changelog describing which paths were added, changed, or removed. + +## Rule Format +- Use one rule per line: ` ...` +- Owners must be GitHub handles prefixed with `@`. +- Keep path style consistent with the file (in this repo, path patterns typically start with `/`). +- Do not leave rules with empty owner lists. 
+ +## Editing Guidelines +- Prefer minimal edits near related sections. +- If a path already exists, update that line instead of adding a second conflicting line. +- If a new rule logically belongs to an existing section, place it in that section. +- Preserve human-readable grouping and blank lines. +- Keep comments intact unless they are clearly outdated and the user asked for cleanup. + +## Validation Checklist +- [ ] Every non-comment, non-empty line has at least one owner. +- [ ] Every owner token starts with `@`. +- [ ] No accidental duplicate rule for the exact same path pattern. +- [ ] Existing comments/sections were preserved unless explicitly changed. + +## Example Requests +- "Add `/mllm/models/new_model/ @alice @bob` under models." +- "Change `/core/Storage` owner to `@team-core`." +- "Remove ownership rule for deprecated path `/legacy/`." diff --git a/.codespellrc b/.codespellrc index 9ddb9d851..bbf02bd17 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,3 +1,3 @@ [codespell] -ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS, bfloat, constexpr, cuda, dlpack, expt, forceinline, ifndef, linalg, LPBQ, mllm, pymllm, Quantizaton, Qwen, ROCM, silu, torchao +ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS, bfloat, constexpr, cuda, dlpack, expt, forceinline, ifndef, linalg, LPBQ, mllm, pymllm, Quantizaton, Qwen, ROCM, silu, torchao, flashinfer skip = *.json,*.jsonl,*.patch,*.txt diff --git a/.gitignore b/.gitignore index 7397d6ecc..7f14b37ec 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,7 @@ .cache/ .tmp/ compile_commands.json -.claude/ +settings.local.json # MLLM Team Specific tasks/mllmteam* diff --git a/assets/pymllm-arch.png b/assets/pymllm-arch.png new file mode 100644 index 000000000..37c48b2a0 Binary files /dev/null and b/assets/pymllm-arch.png differ diff --git a/mllm-kernel/.gitignore b/mllm-kernel/.gitignore index 
df61d0fae..3eefc8fba 100644 --- a/mllm-kernel/.gitignore +++ b/mllm-kernel/.gitignore @@ -3,3 +3,4 @@ build-py/ .vscode/settings.json compile_commands.json .clangd +.pytest_cache/ diff --git a/mllm-kernel/README.md b/mllm-kernel/README.md index 14c8118f0..0a4580495 100644 --- a/mllm-kernel/README.md +++ b/mllm-kernel/README.md @@ -80,31 +80,30 @@ y = add_constant(x, 8) Use the helpers in `mllm_kernel.jit_utils`: -- `load_cpu_jit` -- `load_cuda_jit` +- `jit` - `make_cpp_args` -- `cache_once` -Example pattern: +Recommended pattern (CPU example): ```python import torch -from mllm_kernel.jit_utils import cache_once, load_cpu_jit, make_cpp_args - -@cache_once -def _jit_my_kernel_module(param: int): - args = make_cpp_args(param) - return load_cpu_jit( - "my_kernel", - *args, - cpp_files=["my_kernel.cpp"], - cpp_wrappers=[("my_kernel", f"my_namespace::my_kernel<{args}>")], - ) +import mllm_kernel + +@mllm_kernel.jit( + args=16, + device="cpu", + cpp_files=["my_kernel.cpp"], + cpp_wrappers=[("my_kernel", "my_namespace::my_kernel<16>")], + func_name="my_kernel", +) +def _my_kernel_16(compiled_module, dst: torch.Tensor, src: torch.Tensor) -> None: + compiled_module.my_kernel(dst, src) def my_kernel(src: torch.Tensor, param: int) -> torch.Tensor: + if param != 16: + raise ValueError("This demo only supports param=16.") dst = torch.empty_like(src) - module = _jit_my_kernel_module(param) - module.my_kernel(dst, src) + _my_kernel_16(dst, src) return dst ``` diff --git a/mllm-kernel/benchmarks/bench_create_kv_indices.py b/mllm-kernel/benchmarks/bench_create_kv_indices.py new file mode 100644 index 000000000..f570e66de --- /dev/null +++ b/mllm-kernel/benchmarks/bench_create_kv_indices.py @@ -0,0 +1,218 @@ +"""Benchmark create_kv_indices vs naive torch gather using torch.profiler. 
+ +Example: + python benchmarks/bench_create_kv_indices.py --batch-size 512 --max-reqs 2048 --max-ctx 4096 +""" + +from __future__ import annotations + +import argparse + +import torch +from torch.profiler import ProfilerActivity, profile + +from mllm_kernel.cuda.jit.create_kv_indices import create_kv_indices + + +def _make_batch( + *, + max_reqs: int, + max_ctx: int, + batch_size: int, + use_start_offsets: bool, + device: torch.device, + seed: int, +): + g_cuda = torch.Generator(device=device).manual_seed(seed) + g_cpu = torch.Generator(device="cpu").manual_seed(seed) + + req_to_token = torch.arange( + max_reqs * max_ctx, dtype=torch.int32, device=device + ).reshape(max_reqs, max_ctx) + + assert batch_size <= max_reqs + req_pool_indices = torch.randperm(max_reqs, generator=g_cuda, device=device)[ + :batch_size + ].to(torch.int32) + + page_kernel_lens_list = [] + kv_start_idx_list = [] + for _ in range(batch_size): + L = int(torch.randint(1, max_ctx, (1,), generator=g_cpu).item()) + if use_start_offsets: + start_max = max_ctx - L + start = int(torch.randint(0, max(start_max, 1), (1,), generator=g_cpu).item()) + else: + start = 0 + page_kernel_lens_list.append(L) + kv_start_idx_list.append(start) + + page_kernel_lens = torch.tensor( + page_kernel_lens_list, dtype=torch.int32, device=device + ) + kv_start_idx = torch.tensor(kv_start_idx_list, dtype=torch.int32, device=device) + + kv_indptr = torch.empty(batch_size + 1, dtype=torch.int32, device=device) + kv_indptr[0] = 0 + kv_indptr[1:] = torch.cumsum(page_kernel_lens, dim=0) + + kv_indices = torch.empty( + int(kv_indptr[-1].item()), dtype=torch.int32, device=device + ) + + return ( + req_to_token, + req_pool_indices, + page_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + ) + + +def _profile( + name: str, fn, *, warmup: int, iters: int, row_limit: int, trace_path: str | None +): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + with profile( + activities=[ProfilerActivity.CPU, 
ProfilerActivity.CUDA], + record_shapes=False, + profile_memory=False, + with_stack=False, + ) as prof: + for _ in range(iters): + fn() + torch.cuda.synchronize() + + events = prof.key_averages() + time_attr = ( + "self_cuda_time_total" + if events and hasattr(events[0], "self_cuda_time_total") + else "self_device_time_total" + ) + sort_key = ( + "self_cuda_time_total" + if time_attr == "self_cuda_time_total" + else "self_device_time_total" + ) + total_us = sum(float(getattr(evt, time_attr, 0.0)) for evt in events) + avg_us = total_us / max(iters, 1) + + print(f"\n=== {name} ===") + print( + prof.key_averages().table( + sort_by=sort_key, + row_limit=row_limit, + ) + ) + print(f"{name} total self device time: {total_us:.2f} us") + print(f"{name} avg self device time/iter: {avg_us:.2f} us") + + if trace_path: + prof.export_chrome_trace(trace_path) + print(f"{name} trace exported: {trace_path}") + + return avg_us + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark create_kv_indices vs naive torch gather", + ) + parser.add_argument("--batch-size", type=int, default=512) + parser.add_argument("--max-reqs", type=int, default=2048) + parser.add_argument("--max-ctx", type=int, default=4096) + parser.add_argument("--warmup", type=int, default=50) + parser.add_argument("--iters", type=int, default=200) + parser.add_argument("--row-limit", type=int, default=20) + parser.add_argument("--export-trace-dir", type=str, default="") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--use-start-offsets", + action="store_true", + help="Enable non-zero kv_start_idx to emulate sliding-window decode", + ) + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this benchmark") + + torch.manual_seed(args.seed) + device = torch.device("cuda") + + ( + req_to_token, + req_pool_indices, + page_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + ) = _make_batch( + 
max_reqs=args.max_reqs, + max_ctx=args.max_ctx, + batch_size=args.batch_size, + use_start_offsets=args.use_start_offsets, + device=device, + seed=args.seed, + ) + + print("=== create_kv_indices profiler benchmark ===") + print( + f"batch_size={args.batch_size}, max_reqs={args.max_reqs}, max_ctx={args.max_ctx}, " + f"use_start_offsets={args.use_start_offsets}" + ) + print(f"warmup={args.warmup}, iters={args.iters}, row_limit={args.row_limit}") + + trace_dir = args.export_trace_dir.strip() + kernel_trace = f"{trace_dir}/create_kv_indices_trace.json" if trace_dir else None + torch_trace = f"{trace_dir}/torch_gather_trace.json" if trace_dir else None + + def _run_kernel_once(): + create_kv_indices( + req_to_token, + req_pool_indices, + page_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + ) + + def _run_torch_once(): + # Torch reference implementation on device: gather per-sequence ranges + # from req_to_token into a flat buffer. + out = [] + for i in range(args.batch_size): + req = req_pool_indices[i].item() + start = kv_start_idx[i].item() if args.use_start_offsets else 0 + L = page_kernel_lens[i].item() + row = req_to_token[req, start : start + L] + out.append(row) + torch.cat(out, out=kv_indices) + + kernel_avg_us = _profile( + "create_kv_indices", + _run_kernel_once, + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=kernel_trace, + ) + + torch_avg_us = _profile( + "torch_reference", + _run_torch_once, + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=torch_trace, + ) + + speedup = torch_avg_us / max(kernel_avg_us, 1e-12) + print(f"\nSpeedup: {speedup:.3f}x") + + +if __name__ == "__main__": + main() diff --git a/mllm-kernel/benchmarks/bench_store_cache.py b/mllm-kernel/benchmarks/bench_store_cache.py new file mode 100644 index 000000000..b96fa608b --- /dev/null +++ b/mllm-kernel/benchmarks/bench_store_cache.py @@ -0,0 +1,164 @@ +"""Benchmark store_cache vs torch index with torch.profiler. 
+ +Example: +python benchmarks/bench_store_cache.py --warmup 20 --iters 200 --batch-size 512 --num-slots 8192 +""" + +import argparse + +import torch +from torch.profiler import ProfilerActivity, profile + +from mllm_kernel.cuda.jit import can_use_store_cache, store_cache + + +def _run_store_cache_once( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, +): + store_cache(k, v, k_cache, v_cache, indices) + + +def _run_torch_index_once( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, +): + k_cache[indices] = k + v_cache[indices] = v + + +def _profile_path( + name: str, + fn, + *, + warmup: int, + iters: int, + row_limit: int, + trace_path: str | None, +): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=False, + profile_memory=False, + with_stack=False, + ) as prof: + for _ in range(iters): + fn() + torch.cuda.synchronize() + + events = prof.key_averages() + # torch profiler times are in microseconds. + # PyTorch versions vary between *cuda* and *device* naming. 
+ time_attr = ( + "self_cuda_time_total" + if events and hasattr(events[0], "self_cuda_time_total") + else "self_device_time_total" + ) + sort_key = ( + "self_cuda_time_total" + if time_attr == "self_cuda_time_total" + else "self_device_time_total" + ) + total_self_device_us = sum(float(getattr(evt, time_attr, 0.0)) for evt in events) + avg_self_device_us = total_self_device_us / max(iters, 1) + + print(f"\n=== {name} ===") + print( + prof.key_averages().table( + sort_by=sort_key, + row_limit=row_limit, + ) + ) + print(f"{name} total self device time: {total_self_device_us:.2f} us") + print(f"{name} avg self device time/iter: {avg_self_device_us:.2f} us") + + if trace_path: + prof.export_chrome_trace(trace_path) + print(f"{name} trace exported: {trace_path}") + + return avg_self_device_us + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark store_cache vs torch index using torch.profiler" + ) + parser.add_argument("--batch-size", type=int, default=1024) + parser.add_argument("--num-slots", type=int, default=16384) + parser.add_argument("--head-num", type=int, default=8) + parser.add_argument("--head-dim", type=int, default=128) + parser.add_argument( + "--dtype", + type=str, + default="float16", + choices=["float16", "bfloat16", "float32"], + ) + parser.add_argument("--warmup", type=int, default=50) + parser.add_argument("--iters", type=int, default=200) + parser.add_argument("--row-limit", type=int, default=20) + parser.add_argument("--export-trace-dir", type=str, default="") + parser.add_argument("--seed", type=int, default=0) + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this benchmark") + + torch.manual_seed(args.seed) + device = torch.device("cuda") + dtype = getattr(torch, args.dtype) + + row_dim = args.head_num * args.head_dim + row_bytes = row_dim * torch.tensor([], dtype=dtype).element_size() + if not can_use_store_cache(row_bytes): + raise 
RuntimeError(f"store_cache is unavailable for row_bytes={row_bytes}") + + k = torch.randn(args.batch_size, row_dim, device=device, dtype=dtype) + v = torch.randn(args.batch_size, row_dim, device=device, dtype=dtype) + # Use unique indices to avoid write conflicts. + indices = torch.randperm(args.num_slots, device=device)[: args.batch_size].to( + torch.int64 + ) + k_cache = torch.zeros(args.num_slots, row_dim, device=device, dtype=dtype) + v_cache = torch.zeros_like(k_cache) + print("=== store_cache profiler benchmark ===") + print( + f"shape: batch={args.batch_size}, row_dim={row_dim}, slots={args.num_slots}, dtype={dtype}" + ) + print(f"warmup={args.warmup}, iters={args.iters}, row_limit={args.row_limit}") + + trace_dir = args.export_trace_dir.strip() + store_trace = f"{trace_dir}/store_cache_trace.json" if trace_dir else None + torch_trace = f"{trace_dir}/torch_index_trace.json" if trace_dir else None + + store_avg_us = _profile_path( + "store_cache", + lambda: _run_store_cache_once(k, v, k_cache, v_cache, indices), + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=store_trace, + ) + torch_avg_us = _profile_path( + "torch_index", + lambda: _run_torch_index_once(k, v, k_cache, v_cache, indices), + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=torch_trace, + ) + speedup = torch_avg_us / max(store_avg_us, 1e-12) + print(f"\nSpeedup: {speedup:.3f}x") + + +if __name__ == "__main__": + main() diff --git a/mllm-kernel/mllm_kernel/__main__.py b/mllm-kernel/mllm_kernel/__main__.py index d4888b86c..e5f0779d6 100644 --- a/mllm-kernel/mllm_kernel/__main__.py +++ b/mllm-kernel/mllm_kernel/__main__.py @@ -388,7 +388,7 @@ def main() -> None: logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") parser = argparse.ArgumentParser( - prog="python -m mllm_kernel", + prog="mllm_kernel", description="mllm-kernel helper commands.", ) parser.add_argument( diff --git 
a/mllm-kernel/mllm_kernel/cuda/csrc/create_kv_indices.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/create_kv_indices.cuh new file mode 100644 index 000000000..0b9e4c888 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/create_kv_indices.cuh @@ -0,0 +1,282 @@ +// High-performance CUDA kernel to build FlashInfer KV index arrays from +// pymllm's ReqToTokenPool mapping table. +// +// This is the CUDA-C equivalent of the Triton kernel +// `_create_kv_indices_triton` previously defined in +// `pymllm/layers/attention/flashinfer_backend.py`. +// +// Motivation +// ---------- +// FlashInfer's paged KV attention API expects a *flat* buffer of KV indices +// (`kv_indices`) together with a prefix-sum pointer array (`kv_indptr`). +// +// * `kv_indices` is a 1-D int32 array that stores, for every token of every +// sequence in a batch, the corresponding *slot index* in the KV cache. +// * `kv_indptr` (length = batch_size + 1) stores prefix sums over the +// per-sequence token counts. For sequence `i` we have tokens in: +// +// kv_indices[kv_indptr[i] : kv_indptr[i + 1]] +// +// In pymllm, the mapping from (request_slot, position_in_sequence) to KV slot +// index is stored in a 2-D tensor `req_to_token` owned by `ReqToTokenPool`: +// +// req_to_token[req_slot, position] -> kv_index (int32) +// +// For each batch we also know: +// * which request slots we are serving: `req_pool_indices[bs]` +// * how many tokens to use from each sequence: `page_kernel_lens[bs]` +// * the starting position inside each sequence: `kv_start_idx[bs]` (optional, +// used for sliding-window / partial-context attention) +// +// This kernel converts that 2-D layout into the flat `(kv_indptr, kv_indices)` +// layout in a single, highly parallel CUDA pass: +// +// For each sequence i in the batch: +// - let req = req_pool_indices[i] +// - let len = page_kernel_lens[i] +// - let start = kv_start_idx[i] (or 0 if not provided) +// - let offset = kv_indptr[i] +// - for j in [0, len): +// kv_indices[offset + j] 
= req_to_token[req, start + j] +// +// Requirements / invariants +// ------------------------- +// * `req_to_token` is int32 (aligned with sglang). +// * All tensors must reside on the same CUDA device. +// * The kernel is designed for extremely high throughput: +// - a block is assigned per sequence (batch element), +// - threads cooperate within the block to copy the token range with +// coalesced loads/stores. +// * Shape and dtype checks are performed at runtime via mllm_kernel's +// TensorMatcher utilities, so misuse is caught with clear error messages. +// +// Integration +// ----------- +// The exported entry point is `CreateKvIndicesKernel::run(...)`. The Python +// wrapper in `mllm_kernel/cuda/jit/create_kv_indices.py` JIT-compiles this +// kernel and exposes a `create_kv_indices(...)` function which is then called +// by `pymllm.layers.attention.flashinfer_backend`. + +#pragma once + +#include // TensorMatcher, SymbolicSize, SymbolicDevice, SymbolicDType +#include // div_ceil, RuntimeCheck, Panic +#include // LaunchKernel + +#include +#include + +#include + +namespace { + +// --------------------------------------------------------------------------- +// Parameter block passed to the CUDA kernel +// --------------------------------------------------------------------------- +// +// We keep this struct trivially-copyable so it can be passed via +// `__grid_constant__` if desired. Each field is carefully documented to make +// the data flow explicit. + +struct CreateKvIndicesParams { + // Pointer to ReqToTokenPool mapping table: + // req_to_token[req_slot, position] -> kv_index (int32) + // shape: [max_reqs, max_context_len] + const int32_t* __restrict__ req_to_token; + + // Request slots participating in this batch. + // shape: [batch_size] + const int32_t* __restrict__ req_pool_indices; + + // Number of tokens to copy for each sequence in the batch. 
+ // shape: [batch_size] + const int32_t* __restrict__ page_kernel_lens; + + // Prefix sums over per-sequence token counts. + // kv_indptr[i] is the starting offset in kv_indices for sequence i. + // shape: [batch_size + 1] + const int32_t* __restrict__ kv_indptr; + + // Optional starting position inside each request's sequence. When nullptr, + // we assume start = 0 for all sequences. When non-null, shape is + // [batch_size]. + const int32_t* __restrict__ kv_start_idx; + + // Output flat KV index buffer (int32). Length must be at least + // kv_indptr[batch_size]. + int32_t* __restrict__ kv_indices; + + // Stride of the first dimension of req_to_token, i.e. the number of + // positions per request (max_context_len). + int32_t req_to_token_stride; + + // Number of sequences in the batch. + uint32_t batch_size; + + // Whether kv_start_idx is valid (1) or should be ignored (0). + uint32_t has_kv_start; +}; + +// We use a fixed block size chosen to balance occupancy and per-sequence +// parallelism. Each block is mapped to a single sequence and threads within +// the block cooperate to copy its token range. +constexpr int kBlockSize = 256; + +// --------------------------------------------------------------------------- +// Core CUDA kernel +// --------------------------------------------------------------------------- +// +// Grid mapping: +// * blockIdx.x -> sequence index `i` in [0, batch_size) +// * threadIdx.x -> intra-sequence worker; threads stride over the token +// range [0, len) with step `blockDim.x`. +// +// This design has several advantages: +// * No inter-block synchronisation is required. +// * Memory accesses are fully coalesced because each thread block walks a +// contiguous segment of the `req_to_token` and `kv_indices` arrays. +// * It handles variable-length sequences naturally; sequences with more +// tokens simply iterate more in the inner loop. 
+ +__global__ void create_kv_indices_kernel(const CreateKvIndicesParams params) { + const uint32_t seq_id = blockIdx.x; // which sequence in the batch + if (seq_id >= params.batch_size) { return; } + + // Resolve the request slot for this sequence. + const int32_t req_slot = params.req_pool_indices[seq_id]; + + // Compute the output range [out_offset, out_offset + len) in kv_indices. + const int32_t out_offset = params.kv_indptr[seq_id]; + const int32_t len = params.page_kernel_lens[seq_id]; + + // Compute the starting position inside the original sequence. + int32_t start = 0; + if (params.has_kv_start && params.kv_start_idx != nullptr) { start = params.kv_start_idx[seq_id]; } + + // Base pointers for this sequence. + const int32_t* __restrict__ row = params.req_to_token + static_cast(req_slot) * params.req_to_token_stride; + int32_t* __restrict__ out = params.kv_indices + out_offset; + + // Each thread in the block handles a strided subset of [0, len). + for (int32_t t = threadIdx.x; t < len; t += blockDim.x) { + // Guard against out-of-bounds reads if (start + t) exceeds the + // configured context length. Under normal conditions upstream + // invariants guarantee `start + len <= req_to_token_stride`, but + // this check makes the kernel robust against misconfigured inputs + // and prevents rare segmentation faults observed during testing. + const int32_t pos = start + t; + if (pos < 0 || pos >= params.req_to_token_stride) { continue; } + + out[t] = row[pos]; + } +} + +// --------------------------------------------------------------------------- +// Host-side launcher used by the JIT wrapper +// --------------------------------------------------------------------------- +// +// `CreateKvIndicesKernel::run(...)` is the C++ entry point that will be bound +// to a TVM FFI function and called from Python via the JIT utility. It is +// responsible for: +// 1. Validating tensor shapes / dtypes / devices. +// 2. Extracting symbolic sizes and strides. +// 3. 
Building the parameter block. +// 4. Launching the CUDA kernel using mllm_kernel::host::LaunchKernel. + +struct CreateKvIndicesKernel { + static void run(tvm::ffi::TensorView req_to_token, tvm::ffi::TensorView req_pool_indices, + tvm::ffi::TensorView page_kernel_lens, tvm::ffi::TensorView kv_indptr, tvm::ffi::TensorView kv_start_idx, + tvm::ffi::TensorView kv_indices) { + using namespace mllm_kernel::host; + + // --------------------------------------------------------------------- + // 1. Validate input tensors + // --------------------------------------------------------------------- + // req_to_token: [max_reqs, max_context_len], int32, CUDA + SymbolicSize MaxReqs{"max_reqs"}; + SymbolicSize MaxCtx{"max_context_len"}; + SymbolicSize ReqStride{"req_stride"}; + SymbolicDType req_dtype; + SymbolicDevice device; + + (void)TensorMatcher({MaxReqs, MaxCtx}) + .with_strides({ReqStride, 1}) + .with_dtype(req_dtype) + .with_device(device) + .verify(req_to_token); + + // req_pool_indices: [B], int32, CUDA + SymbolicSize B{"batch_size"}; + SymbolicSize ReqPoolStride{"req_pool_stride"}; + (void)TensorMatcher({B}).with_strides({ReqPoolStride}).with_dtype().with_device(device).verify(req_pool_indices); + + // page_kernel_lens: [B], int32, same device + SymbolicSize PageStride{"page_stride"}; + (void)TensorMatcher({B}).with_strides({PageStride}).with_dtype().with_device(device).verify(page_kernel_lens); + + // kv_indptr: [Nind], int32, same device (we later require Nind >= B + 1) + SymbolicSize Nind{"indptr_len"}; + (void)TensorMatcher({Nind}).with_dtype().with_device(device).verify(kv_indptr); + + // kv_start_idx: either [B] or [0]; int32, same device + SymbolicSize StartLen{"start_len"}; + SymbolicSize StartStride{"start_stride"}; + (void)TensorMatcher({StartLen}).with_strides({StartStride}).with_dtype().with_device(device).verify(kv_start_idx); + + // kv_indices: [Nidx], int32, same device + SymbolicSize Nidx{"num_indices"}; + 
(void)TensorMatcher({Nidx}).with_dtype().with_device(device).verify(kv_indices); + + // Extract concrete sizes. + const int64_t batch_size = B.unwrap(); + const int64_t indptr_len = Nind.unwrap(); + const int64_t req_stride = ReqStride.unwrap(); + + // Basic consistency checks. + RuntimeCheck(batch_size > 0, "batch_size must be positive, got ", batch_size); + RuntimeCheck(indptr_len >= batch_size + 1, "kv_indptr length (", indptr_len, ") must be at least batch_size+1 (", + batch_size + 1, ")"); + + // NOTE: We intentionally do NOT read kv_indptr[batch_size] on the host to + // validate that kv_indices is large enough. kv_indptr resides in device + // memory and dereferencing it from host code would be an illegal memory + // access (segfault). Callers are responsible for ensuring that + // kv_indices.numel() >= kv_indptr[batch_size]. + + // kv_start_idx is optional; when StartLen == 0 we treat it as absent. + RuntimeCheck(StartLen.unwrap() == 0 || StartLen.unwrap() == batch_size, + "kv_start_idx must have length 0 or batch_size; got ", StartLen.unwrap(), " vs batch_size=", batch_size); + + const bool has_kv_start = (StartLen.unwrap() == batch_size); + + // --------------------------------------------------------------------- + // 2. Build parameter block + // --------------------------------------------------------------------- + CreateKvIndicesParams params{ + .req_to_token = static_cast(req_to_token.data_ptr()), + .req_pool_indices = static_cast(req_pool_indices.data_ptr()), + .page_kernel_lens = static_cast(page_kernel_lens.data_ptr()), + .kv_indptr = static_cast(kv_indptr.data_ptr()), + .kv_start_idx = has_kv_start ? static_cast(kv_start_idx.data_ptr()) : nullptr, + .kv_indices = static_cast(kv_indices.data_ptr()), + .req_to_token_stride = static_cast(req_stride), + .batch_size = static_cast(batch_size), + .has_kv_start = has_kv_start ? 
1u : 0u, + }; + + const DLDevice dl_device = device.unwrap(); + + // --------------------------------------------------------------------- + // 3. Launch the CUDA kernel + // --------------------------------------------------------------------- + // We launch one block per sequence so that each sequence can be processed + // independently with fully coalesced memory accesses. The per-thread + // inner loop runs over the token range [0, len) with stride = blockDim.x. + + const int grid_size = static_cast(batch_size); + + LaunchKernel(grid_size, kBlockSize, dl_device)(create_kv_indices_kernel, params); + } +}; + +} // namespace diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/gdn_decode.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/gdn_decode.cuh new file mode 100644 index 000000000..4c2833c06 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/gdn_decode.cuh @@ -0,0 +1,432 @@ +// Fused GDN (Gated Delta Net) decode kernel for linear attention. +// +// Performs a single-token recurrent update per request: +// g = -exp(A_log) * softplus(a + dt_bias) +// beta = sigmoid(b) +// q = L2norm(q) * scale +// k = L2norm(k) +// state *= exp(g) (decay) +// v_delta = v - state @ k (delta rule) +// v_delta *= beta (gated update) +// state += v_delta outer k (state update) +// output = state @ q (readout) +// +// Works on SM80+ (Ampere, Jetson Orin, Hopper, ...). +// Matches the algorithm of sglang's fused_sigmoid_gating_delta_rule_update. +// +// Grid : (NV, bs * HV) where NV = ceil(V / BV) +// Block: BLOCK_K threads (one thread per K-dimension element) +// +// Each thread owns BV state elements at its K position. +// Two cross-thread reductions (over K) compute delta and output dot products. 
+ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +#include + +namespace GDNDecodeKernel { + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +inline constexpr int BV = 32; // V-dimension tile size + +// --------------------------------------------------------------------------- +// Warp-level reduction +// --------------------------------------------------------------------------- + +__device__ __forceinline__ float warp_reduce_sum(float val) { + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + val += __shfl_xor_sync(0xffffffff, val, offset); + } + return val; +} + +// --------------------------------------------------------------------------- +// Type conversion helpers +// --------------------------------------------------------------------------- + +template +__device__ __forceinline__ float to_float(T val); + +template <> +__device__ __forceinline__ float to_float<__half>(__half val) { + return __half2float(val); +} + +template <> +__device__ __forceinline__ float to_float<__nv_bfloat16>(__nv_bfloat16 val) { + return __bfloat162float(val); +} + +template <> +__device__ __forceinline__ float to_float(float val) { + return val; +} + +template +__device__ __forceinline__ T from_float(float val); + +template <> +__device__ __forceinline__ __half from_float<__half>(float val) { + return __float2half(val); +} + +template <> +__device__ __forceinline__ __nv_bfloat16 from_float<__nv_bfloat16>(float val) { + return __float2bfloat16(val); +} + +template <> +__device__ __forceinline__ float from_float(float val) { + return val; +} + +// --------------------------------------------------------------------------- +// Block-level scalar reduction (sum across all threads → broadcast result) +// --------------------------------------------------------------------------- + +// Reduces a 
scalar across all threads in the block. +// Returns the sum in ALL threads (via shared memory broadcast). +// smem must have at least (blockDim.x / 32) floats. +__device__ __forceinline__ float block_reduce_sum(float val, float* smem) { + const int warp_id = threadIdx.x / 32; + const int lane_id = threadIdx.x % 32; + const int num_warps = blockDim.x / 32; + + val = warp_reduce_sum(val); + if (lane_id == 0) smem[warp_id] = val; + __syncthreads(); + + // First warp reduces across warps + if (warp_id == 0) { + float v = (lane_id < num_warps) ? smem[lane_id] : 0.0f; + v = warp_reduce_sum(v); + if (lane_id == 0) smem[0] = v; + } + __syncthreads(); + return smem[0]; +} + +// --------------------------------------------------------------------------- +// Block-level vector reduction: BV independent sums across all K threads +// --------------------------------------------------------------------------- + +// Each thread contributes partial[0..BV-1]. After this call, the results +// are written to out[0..BV-1] and are valid in all threads. +// reduce_buf must have at least BV * num_warps floats. +// broadcast_buf must have at least BV floats. 
+__device__ __forceinline__ void block_reduce_bv( + float partial[BV], + float* reduce_buf, // [num_warps * BV] + float* broadcast_buf, // [BV] + float out[BV] +) { + const int warp_id = threadIdx.x / 32; + const int lane_id = threadIdx.x % 32; + const int num_warps = blockDim.x / 32; + + // Intra-warp reduction for each bv + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + float val = warp_reduce_sum(partial[bv]); + if (lane_id == 0) { + reduce_buf[warp_id * BV + bv] = val; + } + } + __syncthreads(); + + // Inter-warp reduction: threads 0..BV-1 each reduce one bv + if (threadIdx.x < BV) { + float sum = 0.0f; + #pragma unroll 8 + for (int w = 0; w < num_warps; w++) { + sum += reduce_buf[w * BV + threadIdx.x]; + } + broadcast_buf[threadIdx.x] = sum; + } + __syncthreads(); + + // Broadcast to all threads + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + out[bv] = broadcast_buf[bv]; + } +} + +// --------------------------------------------------------------------------- +// Main GDN decode kernel +// --------------------------------------------------------------------------- + +template +__global__ void gdn_decode_kernel( + const T* __restrict__ q_ptr, // [bs, H, K] + const T* __restrict__ k_ptr, // [bs, H, K] + const T* __restrict__ v_ptr, // [bs, HV, V] + const T* __restrict__ a_ptr, // [bs, HV] + const T* __restrict__ b_ptr, // [bs, HV] + const float* __restrict__ A_log_ptr, // [HV] + const float* __restrict__ dt_bias_ptr, // [HV] + float* __restrict__ state_pool, // [pool_size, HV, V, K] + const int64_t* __restrict__ cache_indices, // [bs] + T* __restrict__ output_ptr, // [bs, HV, V] + const int bs, + const int H, // num_k_heads + const int HV, // num_v_heads + const int K, // head_k_dim + const int V, // head_v_dim + const float scale // K^-0.5 +) { + // Block indices + const int bv_block = blockIdx.x; // V-tile index + const int batch_head = blockIdx.y; // batch * HV + const int i_n = batch_head / HV; // batch index + const int i_hv = batch_head % HV; 
// value head index + const int i_h = i_hv * H / HV; // key head index (GQA mapping) + const int k_idx = threadIdx.x; // K-dimension index + const int v_start = bv_block * BV; // V-dimension start + + if (i_n >= bs) return; + + // Shared memory layout (declared dynamically) + extern __shared__ float smem[]; + const int num_warps = BLOCK_K / 32; + float* sq = smem; // [BLOCK_K] + float* sk = smem + BLOCK_K; // [BLOCK_K] + float* sv_broadcast = smem + 2 * BLOCK_K; // [BV] + float* warp_buf = smem + 2 * BLOCK_K + BV; // [num_warps] + float* reduce_buf = smem + 2 * BLOCK_K + BV + num_warps; // [BV * num_warps] + + // ===== 1. Load gating parameters and compute decay + beta ===== + // All threads load the same scalars (cheap, avoids shared memory) + const float A_log_val = A_log_ptr[i_hv]; + const float dt_bias_val = dt_bias_ptr[i_hv]; + const float a_val = to_float(a_ptr[i_n * HV + i_hv]); + const float b_val = to_float(b_ptr[i_n * HV + i_hv]); + + const float x = a_val + dt_bias_val; + // softplus with numerical stability: softplus(x) = log(1+exp(x)), or x for x>20 + const float softplus_x = (x <= 20.0f) ? logf(1.0f + expf(x)) : x; + const float g = -expf(A_log_val) * softplus_x; + const float decay = expf(g); + const float beta = 1.0f / (1.0f + expf(-b_val)); + + // ===== 2. 
Load q, k and compute L2 norms ===== + float q_val = 0.0f, k_val = 0.0f; + if (k_idx < K) { + q_val = to_float(q_ptr[i_n * H * K + i_h * K + k_idx]); + k_val = to_float(k_ptr[i_n * H * K + i_h * K + k_idx]); + } + + // L2 norm: reduce q*q and k*k across block + float q_sq_sum = block_reduce_sum(q_val * q_val, warp_buf); + float k_sq_sum = block_reduce_sum(k_val * k_val, warp_buf); + + float q_norm = rsqrtf(q_sq_sum + 1e-6f); + float k_norm = rsqrtf(k_sq_sum + 1e-6f); + + // Store normalized q (scaled) and k in shared memory + if (k_idx < K) { + sq[k_idx] = q_val * q_norm * scale; + sk[k_idx] = k_val * k_norm; + } else { + sq[k_idx] = 0.0f; + sk[k_idx] = 0.0f; + } + __syncthreads(); + + // ===== 3. Load state elements for this thread ===== + const int64_t pool_idx = cache_indices[i_n]; + // state_pool layout: [pool_size, HV, V, K] + const int64_t state_base = pool_idx * HV * V * K + i_hv * V * K; + + float state[BV]; + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + const int v_idx = v_start + bv; + if (v_idx < V && k_idx < K) { + state[bv] = state_pool[state_base + (int64_t)v_idx * K + k_idx]; + } else { + state[bv] = 0.0f; + } + } + + // ===== 4. Decay: state *= exp(g) ===== + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + state[bv] *= decay; + } + + // ===== 5. Delta: v_delta[bv] = v[bv] - sum_k(state[bv,k] * k_norm[k]) ===== + float partial_delta[BV]; + const float my_k = sk[k_idx]; + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + partial_delta[bv] = state[bv] * my_k; + } + + float delta[BV]; + block_reduce_bv(partial_delta, reduce_buf, sv_broadcast, delta); + + // Compute v_delta = (v - delta) * beta and broadcast to all threads. + // Threads 0..BV-1 each load one v element, compute v_delta, write to smem. + if (k_idx < BV) { + const int my_v_idx = v_start + k_idx; + float my_v = (my_v_idx < V) + ? 
to_float(v_ptr[i_n * HV * V + i_hv * V + my_v_idx]) + : 0.0f; + sv_broadcast[k_idx] = (my_v - delta[k_idx]) * beta; + } + __syncthreads(); + + float v_delta[BV]; + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + v_delta[bv] = sv_broadcast[bv]; + } + + // ===== 6. State update: state[bv,k] += v_delta[bv] * k_norm[k] ===== + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + state[bv] += v_delta[bv] * my_k; + } + + // ===== 7. Output: o[bv] = sum_k(state[bv,k] * q_norm_scaled[k]) ===== + float partial_out[BV]; + const float my_q = sq[k_idx]; + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + partial_out[bv] = state[bv] * my_q; + } + + float out_vals[BV]; + block_reduce_bv(partial_out, reduce_buf, sv_broadcast, out_vals); + + // ===== 8. Store output ===== + // output layout: [bs, HV, V] + if (k_idx < BV) { + const int v_idx = v_start + k_idx; + if (v_idx < V) { + output_ptr[i_n * HV * V + i_hv * V + v_idx] = from_float(out_vals[k_idx]); + } + } + + // ===== 9. Store state back to pool ===== + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + const int v_idx = v_start + bv; + if (v_idx < V && k_idx < K) { + state_pool[state_base + (int64_t)v_idx * K + k_idx] = state[bv]; + } + } +} + +// --------------------------------------------------------------------------- +// Launch wrapper (called via TVM FFI) +// --------------------------------------------------------------------------- + +void run( + tvm::ffi::TensorView q, // [bs, H, K] + tvm::ffi::TensorView k, // [bs, H, K] + tvm::ffi::TensorView v, // [bs, HV, V] + tvm::ffi::TensorView a, // [bs, HV] + tvm::ffi::TensorView b, // [bs, HV] + tvm::ffi::TensorView A_log, // [HV] + tvm::ffi::TensorView dt_bias, // [HV] + tvm::ffi::TensorView state_pool, // [pool_size, HV, V, K] + tvm::ffi::TensorView cache_indices, // [bs] + tvm::ffi::TensorView output // [bs, HV, V] +) { + using namespace mllm_kernel::host; + + // --- Extract dimensions --- + auto BS = SymbolicSize{"bs"}; + auto H_ = SymbolicSize{"H"}; + auto 
HV_ = SymbolicSize{"HV"}; + auto K_ = SymbolicSize{"K"}; + auto V_ = SymbolicSize{"V"}; + auto PS = SymbolicSize{"pool_size"}; + auto dtype = SymbolicDType{}; + auto device = SymbolicDevice{}; + device.set_options(); + dtype.set_options(); + + (void)TensorMatcher({BS, H_, K_}).with_dtype(dtype).with_device(device).verify(q); + (void)TensorMatcher({BS, H_, K_}).with_dtype(dtype).with_device(device).verify(k); + (void)TensorMatcher({BS, HV_, V_}).with_dtype(dtype).with_device(device).verify(v); + (void)TensorMatcher({BS, HV_}).with_dtype(dtype).with_device(device).verify(a); + (void)TensorMatcher({BS, HV_}).with_dtype(dtype).with_device(device).verify(b); + (void)TensorMatcher({HV_}).with_dtype().with_device(device).verify(A_log); + (void)TensorMatcher({HV_}).with_dtype().with_device(device).verify(dt_bias); + (void)TensorMatcher({PS, HV_, V_, K_}).with_dtype().with_device(device).verify(state_pool); + (void)TensorMatcher({BS}).with_device(device).verify(cache_indices); + (void)TensorMatcher({BS, HV_, V_}).with_dtype(dtype).with_device(device).verify(output); + + const int bs = static_cast(BS.unwrap()); + const int H = static_cast(H_.unwrap()); + const int HV = static_cast(HV_.unwrap()); + const int K = static_cast(K_.unwrap()); + const int V = static_cast(V_.unwrap()); + const float scale = 1.0f / sqrtf(static_cast(K)); + + // Block size = K (rounded up to warp multiple, max 1024) + int block_k = ((K + 31) / 32) * 32; + if (block_k > 1024) block_k = 1024; + const int num_warps = block_k / 32; + + // Grid + const int NV = (V + BV - 1) / BV; + dim3 grid(NV, bs * HV); + dim3 block(block_k); + + // Dynamic shared memory: sq[block_k] + sk[block_k] + sv[BV] + warp_buf[nw] + reduce[BV*nw] + const size_t smem_bytes = (2 * block_k + BV + num_warps + BV * num_warps) * sizeof(float); + + const DLDevice dl_device = device.unwrap(); + + // Typed launch helper + #define LAUNCH_GDN_DECODE(CType, BKVAL) \ + LaunchKernel(grid, block, dl_device, smem_bytes)( \ + gdn_decode_kernel, \ 
+ static_cast(q.data_ptr()), \ + static_cast(k.data_ptr()), \ + static_cast(v.data_ptr()), \ + static_cast(a.data_ptr()), \ + static_cast(b.data_ptr()), \ + static_cast(A_log.data_ptr()), \ + static_cast(dt_bias.data_ptr()), \ + static_cast(state_pool.data_ptr()), \ + static_cast(cache_indices.data_ptr()), \ + static_cast(output.data_ptr()), \ + bs, H, HV, K, V, scale \ + ) + + // Dispatch based on dtype and block size + if (dtype.is_type()) { + if (block_k == 64) { LAUNCH_GDN_DECODE(__nv_bfloat16, 64); } + else if (block_k == 128) { LAUNCH_GDN_DECODE(__nv_bfloat16, 128); } + else if (block_k == 256) { LAUNCH_GDN_DECODE(__nv_bfloat16, 256); } + else { LAUNCH_GDN_DECODE(__nv_bfloat16, 256); } + } else { + if (block_k == 64) { LAUNCH_GDN_DECODE(__half, 64); } + else if (block_k == 128) { LAUNCH_GDN_DECODE(__half, 128); } + else if (block_k == 256) { LAUNCH_GDN_DECODE(__half, 256); } + else { LAUNCH_GDN_DECODE(__half, 256); } + } + + #undef LAUNCH_GDN_DECODE +} + +} // namespace GDNDecodeKernel diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/rms_norm_gated.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/rms_norm_gated.cuh new file mode 100644 index 000000000..b61246029 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/rms_norm_gated.cuh @@ -0,0 +1,212 @@ +// Fused RMSNorm with optional SiLU gating for Qwen3.5 GDN attention. +// +// Computes: output = rmsnorm(x, weight, eps) * silu(z) (if z provided) +// output = rmsnorm(x, weight, eps) (if z is null) +// +// Where: rmsnorm(x) = x / sqrt(mean(x^2) + eps) * weight +// silu(z) = z * sigmoid(z) +// +// This kernel fuses both operations into a single pass over the data, +// maximizing memory bandwidth utilization. Each block processes one row +// (one token position). +// +// Supported dtypes: float16, bfloat16 (accumulation in float32). 
+
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+
+namespace RMSNormGatedKernel {
+
+// ---------------------------------------------------------------------------
+// Warp-level reduction
+// ---------------------------------------------------------------------------
+
+// Butterfly (XOR) reduction over the 32 lanes of a warp; after the loop
+// every lane holds the full warp sum (not just lane 0).
+__device__ __forceinline__ float warp_reduce_sum(float val) {
+  #pragma unroll
+  for (int offset = 16; offset > 0; offset >>= 1) {
+    val += __shfl_xor_sync(0xffffffff, val, offset);
+  }
+  return val;
+}
+
+// ---------------------------------------------------------------------------
+// Type conversion helpers
+// ---------------------------------------------------------------------------
+
+// Widen a storage type (half / bf16 / float) to float for accumulation.
+template 
+__device__ __forceinline__ float to_float(T val);
+
+template <>
+__device__ __forceinline__ float to_float(half val) {
+  return __half2float(val);
+}
+
+template <>
+__device__ __forceinline__ float to_float<__nv_bfloat16>(__nv_bfloat16 val) {
+  return __bfloat162float(val);
+}
+
+template <>
+__device__ __forceinline__ float to_float(float val) {
+  return val;
+}
+
+// Narrow a float accumulator back to the storage type.
+template 
+__device__ __forceinline__ T from_float(float val);
+
+template <>
+__device__ __forceinline__ half from_float(float val) {
+  return __float2half(val);
+}
+
+template <>
+__device__ __forceinline__ __nv_bfloat16 from_float<__nv_bfloat16>(float val) {
+  return __float2bfloat16(val);
+}
+
+template <>
+__device__ __forceinline__ float from_float(float val) {
+  return val;
+}
+
+// ---------------------------------------------------------------------------
+// Main kernel
+// ---------------------------------------------------------------------------
+
+// One thread block per row; BLOCK_SIZE threads stride over the N columns.
+// Two passes over the row:
+//   pass 1: block-wide sum of squares -> rms = rsqrt(mean_sq + eps)
+//   pass 2: out = x * rms * weight, optionally multiplied by silu(gate)
+// All arithmetic is done in fp32 regardless of the storage type T.
+template 
+__global__ void rms_norm_gated_kernel(
+    T* __restrict__ output,       // [M, N]
+    const T* __restrict__ input,  // [M, N]
+    const T* __restrict__ weight, // [N]
+    const T* __restrict__ gate,   // [M, N] or nullptr
+    const int M,                  // number of rows
+    const int N,                  // number of columns (hidden_size)
+    const float eps
+) {
+  const int row = blockIdx.x;
+  if (row >= M) return;
+
+  const int tid = threadIdx.x;
+  // NOTE(review): row addressing assumes fully contiguous [M, N] storage
+  // (row stride exactly N) — the Python wrapper must guarantee this.
+  const T* x_row = input + row * N;
+  T* out_row = output + row * N;
+  const T* z_row = (gate != nullptr) ? gate + row * N : nullptr;
+
+  // --- Pass 1: compute sum of squares ---
+  float sum_sq = 0.0f;
+  for (int col = tid; col < N; col += BLOCK_SIZE) {
+    float val = to_float(x_row[col]);
+    sum_sq += val * val;
+  }
+
+  // Block-level reduction
+  __shared__ float shared_sum[32]; // one per warp
+  int warp_id = tid / 32;
+  int lane_id = tid % 32;
+
+  sum_sq = warp_reduce_sum(sum_sq);
+  if (lane_id == 0) {
+    shared_sum[warp_id] = sum_sq;
+  }
+  __syncthreads();
+
+  // Final reduction in first warp
+  if (warp_id == 0) {
+    float val = (lane_id < (BLOCK_SIZE / 32)) ? shared_sum[lane_id] : 0.0f;
+    val = warp_reduce_sum(val);
+    if (lane_id == 0) {
+      shared_sum[0] = val;
+    }
+  }
+  __syncthreads();
+
+  // NOTE(review): eps is added to the mean square *inside* rsqrtf, i.e.
+  // rms = 1/sqrt(mean(x^2) + eps) — confirm this matches the reference
+  // RMSNorm the model was trained with (some variants add eps elsewhere).
+  float rms = rsqrtf(shared_sum[0] / (float)N + eps);
+
+  // --- Pass 2: normalize, scale by weight, optionally gate with silu(z) ---
+  for (int col = tid; col < N; col += BLOCK_SIZE) {
+    float val = to_float(x_row[col]);
+    float w = to_float(weight[col]);
+
+    float normed = val * rms * w;
+
+    if (z_row != nullptr) {
+      float z = to_float(z_row[col]);
+      // silu(z) = z * sigmoid(z)
+      float silu_z = z / (1.0f + expf(-z));
+      normed *= silu_z;
+    }
+
+    out_row[col] = from_float(normed);
+  }
+}
+
+// ---------------------------------------------------------------------------
+// Launch wrapper (called via TVM FFI)
+// ---------------------------------------------------------------------------
+
+// Validates input/output/weight with TensorMatcher, then dispatches on the
+// (shared) dtype. A `gate` with numel()==0 means "no gating"; the kernel
+// then receives nullptr and skips the silu multiply.
+void run(
+    tvm::ffi::TensorView output,
+    tvm::ffi::TensorView input,
+    tvm::ffi::TensorView weight,
+    tvm::ffi::TensorView gate, // empty tensor (numel==0) means no gate
+    double eps
+) {
+  using namespace mllm_kernel::host;
+
+  auto M = SymbolicSize{"M"};
+  auto N = SymbolicSize{"N"};
+  auto dtype = SymbolicDType{};
+  auto device = SymbolicDevice{};
+  device.set_options();
+  dtype.set_options();
+
+  (void)TensorMatcher({M, N}).with_dtype(dtype).with_device(device).verify(input);
+  (void)TensorMatcher({M, N}).with_dtype(dtype).with_device(device).verify(output);
+  (void)TensorMatcher({N}).with_dtype(dtype).with_device(device).verify(weight);
+
+  // NOTE(review): `gate` is never run through a TensorMatcher — a non-empty
+  // gate with the wrong shape/dtype/device is passed to the kernel
+  // unchecked. Consider verifying it as {M, N} when gate.numel() > 0.
+  const int rows = static_cast(M.unwrap());
+  const int cols = static_cast(N.unwrap());
+  const bool has_gate = (gate.numel() > 0);
+
+  constexpr int BLOCK_SIZE = 256;
+
+  if (dtype.is_type()) {
+    LaunchKernel(rows, BLOCK_SIZE, device.unwrap())(
+        rms_norm_gated_kernel,
+        static_cast(output.data_ptr()),
+        static_cast(input.data_ptr()),
+        static_cast(weight.data_ptr()),
+        has_gate ? static_cast(gate.data_ptr()) : nullptr,
+        rows, cols, static_cast(eps)
+    );
+  } else if (dtype.is_type()) {
+    LaunchKernel(rows, BLOCK_SIZE, device.unwrap())(
+        rms_norm_gated_kernel<__nv_bfloat16, BLOCK_SIZE>,
+        static_cast<__nv_bfloat16*>(output.data_ptr()),
+        static_cast(input.data_ptr()),
+        static_cast(weight.data_ptr()),
+        has_gate ? static_cast(gate.data_ptr()) : nullptr,
+        rows, cols, static_cast(eps)
+    );
+  } else {
+    LaunchKernel(rows, BLOCK_SIZE, device.unwrap())(
+        rms_norm_gated_kernel,
+        static_cast(output.data_ptr()),
+        static_cast(input.data_ptr()),
+        static_cast(weight.data_ptr()),
+        has_gate ? static_cast(gate.data_ptr()) : nullptr,
+        rows, cols, static_cast(eps)
+    );
+  }
+}
+
+} // namespace RMSNormGatedKernel
diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/store_cache.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/store_cache.cuh
new file mode 100644
index 000000000..05daabee0
--- /dev/null
+++ b/mllm-kernel/mllm_kernel/cuda/csrc/store_cache.cuh
@@ -0,0 +1,202 @@
+// Copyright SGLang Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Store KV cache kernel: efficiently scatter key/value tensors into a
+// pre-allocated KV cache pool using warp-level vectorized copies.
+//
+// Reference: sglang jit_kernel/csrc/elementwise/kvcache.cuh
+
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#include 
+
+namespace {
+
+// ───────────────────────────────────────────────────────────────
+// Parameter block passed to the kernel via __grid_constant__
+// ───────────────────────────────────────────────────────────────
+
+// All strides are expressed in *bytes* except stride_indices, which is in
+// elements of the index dtype (it is multiplied onto a typed pointer).
+struct StoreKVCacheParams {
+  const void* __restrict__ k;        // [B, row] source keys
+  const void* __restrict__ v;        // [B, row] source values
+  void* __restrict__ k_cache;        // destination key pool
+  void* __restrict__ v_cache;        // destination value pool
+  const void* __restrict__ indices;  // [B] destination slot per row
+  int64_t stride_k_bytes;
+  int64_t stride_v_bytes;
+  int64_t stride_cache_bytes;
+  int64_t stride_indices;
+  uint32_t batch_size;
+};
+
+constexpr uint32_t kNumWarps = 4;
+constexpr uint32_t kThreadsPerBlock = kNumWarps * device::kWarpThreads;
+
+// ───────────────────────────────────────────────────────────────
+// Vectorized warp-level KV copy
+// ───────────────────────────────────────────────────────────────
+//
+// Each warp copies kElementBytes of K data and kElementBytes of V
+// data using the widest possible aligned vector type (uint4 = 16B,
+// uint2 = 8B, or uint32_t = 4B).
+
+namespace detail {
+
+// Lanes of the calling warp cooperatively copy num_vecs vector-sized
+// chunks from src to dst (lane-strided loop).
+template 
+__device__ __forceinline__ void warp_copy_bytes(const void* __restrict__ src, void* __restrict__ dst, int64_t num_vecs) {
+  const int lane = threadIdx.x % device::kWarpThreads;
+  const auto* s = static_cast(src);
+  auto* d = static_cast(dst);
+  for (int64_t i = lane; i < num_vecs; i += device::kWarpThreads) { d[i] = s[i]; }
+}
+
+} // namespace detail
+
+// Copy one K row chunk and one V row chunk with the widest vector width
+// that divides kElementBytes. NOTE(review): vector width selection implies
+// the underlying pointers are at least that aligned — TODO confirm the
+// allocator guarantees 16B alignment for the 16B path.
+template 
+__device__ __forceinline__ void copy_kv_warp(const void* __restrict__ k_src, const void* __restrict__ v_src,
+                                             void* __restrict__ k_dst, void* __restrict__ v_dst) {
+  static_assert(kElementBytes > 0 && kElementBytes % 4 == 0, "Element size must be a positive multiple of 4 bytes");
+
+  // Pick the widest aligned vector type the element size supports.
+  if constexpr (kElementBytes % 16 == 0) {
+    constexpr int64_t N = kElementBytes / 16;
+    detail::warp_copy_bytes(k_src, k_dst, N);
+    detail::warp_copy_bytes(v_src, v_dst, N);
+  } else if constexpr (kElementBytes % 8 == 0) {
+    constexpr int64_t N = kElementBytes / 8;
+    detail::warp_copy_bytes(k_src, k_dst, N);
+    detail::warp_copy_bytes(v_src, v_dst, N);
+  } else {
+    constexpr int64_t N = kElementBytes / 4;
+    detail::warp_copy_bytes(k_src, k_dst, N);
+    detail::warp_copy_bytes(v_src, v_dst, N);
+  }
+}
+
+// ───────────────────────────────────────────────────────────────
+// Main kernel
+// ───────────────────────────────────────────────────────────────
+//
+// Template parameters:
+//   kElementBytes  total bytes per token row (head_num * head_dim * dtype_size)
+//   kSplit         how many warps collaborate on one element (1, 2, or 4)
+//   kUsePDL        whether to emit PDL synchronisation instructions
+//   T              index dtype (int32_t or int64_t)
+
+template 
+__global__ void store_kvcache(const __grid_constant__ StoreKVCacheParams params) {
+  using namespace device;
+  constexpr auto kSplitSize = kElementBytes / kSplit;
+
+  // Global warp id -> (item, split) assignment: kSplit consecutive warps
+  // share one batch item, each handling a kSplitSize byte slice.
+  const uint32_t warp_id = blockIdx.x * kNumWarps + threadIdx.x / kWarpThreads;
+  const uint32_t item_id = warp_id / kSplit;
+  const uint32_t split_id = warp_id % kSplit;
+
+  const auto& [k_input, v_input, k_cache, v_cache, indices, stride_k, stride_v, stride_cache, stride_indices, batch_size] =
+      params;
+
+  if (item_id >= batch_size) return;
+
+  const auto index_ptr = static_cast(indices) + item_id * stride_indices;
+  PDLWaitPrimary();
+
+  // NOTE(review): `index` is used to address the cache without any bounds
+  // check — an out-of-range entry in `indices` is an out-of-bounds global
+  // write. The host side does not validate it either; callers must.
+  const auto index = *index_ptr;
+  const auto k_src = pointer::offset(k_input, item_id * stride_k, split_id * kSplitSize);
+  const auto v_src = pointer::offset(v_input, item_id * stride_v, split_id * kSplitSize);
+  const auto k_dst = pointer::offset(k_cache, index * stride_cache, split_id * kSplitSize);
+  const auto v_dst = pointer::offset(v_cache, index * stride_cache, split_id * kSplitSize);
+
+  copy_kv_warp(k_src, v_src, k_dst, v_dst);
+  PDLTriggerSecondary();
+}
+
+template 
+struct StoreKVCacheKernel {
+  static_assert(kElementBytes > 0 && kElementBytes % 4 == 0);
+
+  template 
+  static constexpr auto store_kernel = store_kvcache;
+
+  // Select the kernel instantiation for the requested warp split, gated on
+  // the per-split slice staying vector-aligned.
+  // NOTE(review): the fall-through ends in Panic; this compiles cleanly only
+  // if Panic is [[noreturn]] — confirm, otherwise a return is missing.
+  template 
+  static auto get_kernel(int num_split) {
+    using namespace mllm_kernel::host;
+    if constexpr (kElementBytes % (4 * 128) == 0) {
+      if (num_split == 4) return store_kernel<4, T>;
+    }
+    if constexpr (kElementBytes % (2 * 128) == 0) {
+      if (num_split == 2) return store_kernel<2, T>;
+    }
+    if (num_split == 1) return store_kernel<1, T>;
+    Panic("Unsupported num_split ", num_split, " for element size ", kElementBytes);
+  }
+
+  // FFI entry point: validate shapes/strides, build the param block, and
+  // launch ceil(B * num_split / kNumWarps) blocks of kNumWarps warps.
+  static void run(tvm::ffi::TensorView k, tvm::ffi::TensorView v, tvm::ffi::TensorView k_cache, tvm::ffi::TensorView v_cache,
+                  tvm::ffi::TensorView indices, int num_split) {
+    using namespace mllm_kernel::host;
+
+    auto B = SymbolicSize{"batch_size"};
+    auto D = SymbolicSize{"element_size"};
+    auto KS = SymbolicSize{"k_stride"};
+    auto VS = SymbolicSize{"v_stride"};
+    auto S = SymbolicSize{"cache_stride"};
+    auto I = SymbolicSize{"indices_stride"};
+    auto dtype = SymbolicDType{};
+    auto device = SymbolicDevice{};
+    auto indice_dtype = SymbolicDType{};
+    device.set_options();
+
+    // k, v: [B, D] with strides [KS, 1]
+    (void)TensorMatcher({B, D}).with_strides({KS, 1}).with_dtype(dtype).with_device(device).verify(k);
+    (void)TensorMatcher({B, D}).with_strides({VS, 1}).with_dtype(dtype).with_device(device).verify(v);
+
+    // k_cache, v_cache: [*, D] with strides [S, 1]
+    (void)TensorMatcher({-1, D}).with_strides({S, 1}).with_dtype(dtype).with_device(device).verify(k_cache).verify(v_cache);
+
+    // indices: [B] with strides [I]
+    (void)TensorMatcher({B}).with_strides({I}).with_dtype(indice_dtype).with_device(device).verify(indices);
+
+    const int64_t dtype_size = dtype_bytes(dtype.unwrap());
+    const uint32_t num_elements = static_cast(B.unwrap());
+    // The compile-time row size baked into this instantiation must match
+    // the runtime row size of the tensors being stored.
+    RuntimeCheck(kElementBytes == dtype_size * D.unwrap(), "Element size mismatch: expected ", kElementBytes, " but got ",
+                 dtype_size * D.unwrap());
+
+    const auto params = StoreKVCacheParams{
+        .k = k.data_ptr(),
+        .v = v.data_ptr(),
+        .k_cache = k_cache.data_ptr(),
+        .v_cache = v_cache.data_ptr(),
+        .indices = indices.data_ptr(),
+        .stride_k_bytes = KS.unwrap() * dtype_size,
+        .stride_v_bytes = VS.unwrap() * dtype_size,
+        .stride_cache_bytes = S.unwrap() * dtype_size,
+        .stride_indices = I.unwrap(),
+        .batch_size = num_elements,
+    };
+
+    const auto use_int32 = indice_dtype.is_type();
+    const auto kernel = use_int32 ? get_kernel(num_split) : get_kernel(num_split);
+    const auto num_blocks = div_ceil(num_elements * num_split, kNumWarps);
+
+    LaunchKernel(num_blocks, kThreadsPerBlock, device.unwrap()).enable_pdl(kUsePDL)(kernel, params);
+  }
+};
+
+} // namespace
diff --git a/mllm-kernel/requirements.txt b/mllm-kernel/mllm_kernel/cuda/csrc/vocab_embedding.cuh
similarity index 100%
rename from mllm-kernel/requirements.txt
rename to mllm-kernel/mllm_kernel/cuda/csrc/vocab_embedding.cuh
diff --git a/mllm-kernel/mllm_kernel/cuda/jit/__init__.py b/mllm-kernel/mllm_kernel/cuda/jit/__init__.py
index 696e73ea0..cc4ab667a 100644
--- a/mllm-kernel/mllm_kernel/cuda/jit/__init__.py
+++ b/mllm-kernel/mllm_kernel/cuda/jit/__init__.py
@@ -1,3 +1,5 @@
 from .add_constant import add_constant
+from .gdn_decode import gdn_decode
+from .store_cache import can_use_store_cache, store_cache
-__all__ = ["add_constant"]
+__all__ = ["add_constant", "can_use_store_cache", "gdn_decode", "store_cache"]
diff --git a/mllm-kernel/mllm_kernel/cuda/jit/create_kv_indices.py b/mllm-kernel/mllm_kernel/cuda/jit/create_kv_indices.py
new file mode 100644
index 000000000..565686a40
--- /dev/null
+++ b/mllm-kernel/mllm_kernel/cuda/jit/create_kv_indices.py
@@ -0,0 +1,118 @@
+"""High-performance CUDA JIT wrapper for create_kv_indices.
+
+This module exposes a single function:
+
+    create_kv_indices(req_to_token, req_pool_indices,
+                      page_kernel_lens, kv_indptr,
+                      kv_start_idx, kv_indices)
+
+which is a Python binding around the C++/CUDA kernel defined in
+`mllm_kernel/cuda/csrc/create_kv_indices.cuh`.
+
+The kernel transforms pymllm's 2-D ReqToTokenPool mapping table into the flat
+`(kv_indptr, kv_indices)` layout expected by FlashInfer's paged KV attention
+wrappers. It is carefully written for maximum throughput and is intended to
+replace the Triton implementation `_create_kv_indices_triton` in
+`pymllm.layers.attention.flashinfer_backend`.
+""" + +from __future__ import annotations + +import torch + +from mllm_kernel.jit_utils import cache_once, jit + + +@cache_once +def _make_create_kv_indices_kernel(): + """JIT-compile the CUDA kernel and return a callable wrapper. + + The JIT system will: + * locate `create_kv_indices.cuh` under the mllm-kernel CUDA csrc tree, + * compile it into a TVM FFI module, + * expose `CreateKvIndicesKernel::run` as `compiled_module.create_kv_indices`. + """ + + @jit( + args=[], + device="cuda", + cuda_files=["create_kv_indices.cuh"], + cpp_wrappers=[], + cuda_wrappers=[ + ("create_kv_indices", "CreateKvIndicesKernel::run"), + ], + func_name="create_kv_indices", + ) + def _kernel( + compiled_module, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + page_kernel_lens: torch.Tensor, + kv_indptr: torch.Tensor, + kv_start_idx: torch.Tensor, + kv_indices: torch.Tensor, + ) -> None: + compiled_module.create_kv_indices( + req_to_token, + req_pool_indices, + page_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + ) + + return _kernel + + +def create_kv_indices( + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + page_kernel_lens: torch.Tensor, + kv_indptr: torch.Tensor, + kv_start_idx: torch.Tensor | None, + kv_indices: torch.Tensor, +) -> None: + """Fill a flat KV-index buffer from the ReqToTokenPool mapping. + + This is a thin Python wrapper that forwards to the JIT-compiled CUDA + kernel. All tensors must be placed on the same CUDA device. + + Args + ---- + req_to_token: + Mapping tensor from ReqToTokenPool, shape + ``[max_reqs, max_context_len]``, dtype ``torch.int32``. + req_pool_indices: + Request slots participating in this batch, shape ``[batch_size]``, + dtype ``torch.int32``. + page_kernel_lens: + Per-sequence token counts (how many tokens to attend), shape + ``[batch_size]``, dtype ``torch.int32``. + kv_indptr: + Prefix sums over per-sequence token counts, shape ``[batch_size + 1]``, + dtype ``torch.int32``. 
``kv_indptr[i]`` is the starting offset in + ``kv_indices`` for sequence ``i``. + kv_start_idx: + Optional starting positions inside each sequence, shape + ``[batch_size]`` or ``[0]``, dtype ``torch.int32``. When + ``None``, the kernel assumes 0 for all sequences. + kv_indices: + Output flat KV-index buffer, shape ``[N]``, dtype ``torch.int32``. + ``N`` must be at least ``kv_indptr[batch_size]``. + """ + if kv_start_idx is None: + # Use an empty tensor to signal "no start offsets". The C++ launcher + # treats length==0 as "no kv_start" and will pass a nullptr into the + # parameter block, which is slightly cheaper than materialising a + # full zero tensor on every call. + kv_start_idx = req_pool_indices.new_empty(0, dtype=torch.int32) + + kernel = _make_create_kv_indices_kernel() + kernel( + req_to_token, + req_pool_indices, + page_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + ) diff --git a/mllm-kernel/mllm_kernel/cuda/jit/gdn_decode.py b/mllm-kernel/mllm_kernel/cuda/jit/gdn_decode.py new file mode 100644 index 000000000..53aaeaab3 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/jit/gdn_decode.py @@ -0,0 +1,114 @@ +"""Fused GDN decode CUDA JIT kernel. + +Performs a single-token GDN (Gated Delta Net) recurrent update per request, +fusing gating + L2 normalization + delta rule + output computation into +one kernel. Works on SM80+ (Ampere, Jetson Orin, Hopper, ...). 
+
+Usage::
+
+    from mllm_kernel.cuda.jit.gdn_decode import gdn_decode
+
+    output = gdn_decode(q, k, v, a, b, A_log, dt_bias, state_pool, cache_indices)
+"""
+
+from __future__ import annotations
+
+import torch
+
+from mllm_kernel.jit_utils import cache_once, jit
+
+
+@cache_once
+def _make_gdn_decode_kernel():
+    """JIT-compile the fused GDN decode CUDA kernel."""
+
+    @jit(
+        args=[],
+        device="cuda",
+        cuda_files=["gdn_decode.cuh"],
+        cpp_wrappers=[],
+        cuda_wrappers=[
+            ("gdn_decode", "GDNDecodeKernel::run"),
+        ],
+        func_name="gdn_decode",
+    )
+    def _kernel(
+        compiled_module,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        a: torch.Tensor,
+        b: torch.Tensor,
+        A_log: torch.Tensor,
+        dt_bias: torch.Tensor,
+        state_pool: torch.Tensor,
+        cache_indices: torch.Tensor,
+        output: torch.Tensor,
+    ) -> None:
+        compiled_module.gdn_decode(
+            q, k, v, a, b, A_log, dt_bias, state_pool, cache_indices, output
+        )
+
+    return _kernel
+
+
+def gdn_decode(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    a: torch.Tensor,
+    b: torch.Tensor,
+    A_log: torch.Tensor,
+    dt_bias: torch.Tensor,
+    state_pool: torch.Tensor,
+    cache_indices: torch.Tensor,
+) -> torch.Tensor:
+    """Fused GDN decode: gating + L2 norm + delta rule + output.
+
+    Parameters
+    ----------
+    q : torch.Tensor
+        Query tensor, shape ``(bs, num_k_heads, head_k_dim)``, bf16/fp16.
+    k : torch.Tensor
+        Key tensor, shape ``(bs, num_k_heads, head_k_dim)``, bf16/fp16.
+    v : torch.Tensor
+        Value tensor, shape ``(bs, num_v_heads, head_v_dim)``, bf16/fp16.
+    a : torch.Tensor
+        Decay gate input, shape ``(bs, num_v_heads)``, bf16/fp16.
+    b : torch.Tensor
+        Update gate input, shape ``(bs, num_v_heads)``, bf16/fp16.
+    A_log : torch.Tensor
+        Log-space decay parameter, shape ``(num_v_heads,)``, float32.
+    dt_bias : torch.Tensor
+        Bias for decay gate, shape ``(num_v_heads,)``, float32.
+    state_pool : torch.Tensor
+        Pooled recurrent state, shape ``(pool_size, num_v_heads, head_v_dim, head_k_dim)``,
+        float32. Modified in-place.
+    cache_indices : torch.Tensor
+        Pool indices per request, shape ``(bs,)``, int64.
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor, shape ``(bs, num_v_heads, head_v_dim)``, same dtype as v.
+    """
+    bs = q.shape[0]
+    num_v_heads = v.shape[1]
+    head_v_dim = v.shape[2]
+
+    output = torch.empty(bs, num_v_heads, head_v_dim, dtype=v.dtype, device=v.device)
+
+    # NOTE(review): the .contiguous()/.to(torch.int64) calls below allocate a
+    # new tensor whenever the input is non-contiguous or int32 — on the decode
+    # hot path this happens every step; callers should pre-convert if possible.
+    # state_pool is passed through without a contiguity check — presumably the
+    # pool is allocated contiguous; confirm against the pool allocator.
+    kernel = _make_gdn_decode_kernel()
+    kernel(
+        q.contiguous(),
+        k.contiguous(),
+        v.contiguous(),
+        a.contiguous(),
+        b.contiguous(),
+        A_log.contiguous(),
+        dt_bias.contiguous(),
+        state_pool,
+        cache_indices.to(torch.int64).contiguous(),
+        output,
+    )
+    return output
diff --git a/mllm-kernel/mllm_kernel/cuda/jit/rms_norm_gated.py b/mllm-kernel/mllm_kernel/cuda/jit/rms_norm_gated.py
new file mode 100644
index 000000000..d7906a383
--- /dev/null
+++ b/mllm-kernel/mllm_kernel/cuda/jit/rms_norm_gated.py
@@ -0,0 +1,87 @@
+"""Fused RMSNorm + SiLU gating CUDA JIT kernel for Qwen3.5 GDN attention.
+
+Computes ``rmsnorm(x, weight, eps) * silu(z)`` in a single fused pass.
+
+Usage::
+
+    from mllm_kernel.cuda.jit.rms_norm_gated import rms_norm_gated
+
+    output = rms_norm_gated(x, weight, z=gate, eps=1e-6)
+"""
+
+from __future__ import annotations
+
+import torch
+
+from mllm_kernel.jit_utils import cache_once, jit
+
+
+@cache_once
+def _make_rms_norm_gated_kernel():
+    """JIT-compile the fused RMSNorm+gating CUDA kernel."""
+
+    @jit(
+        args=[],
+        device="cuda",
+        cuda_files=["rms_norm_gated.cuh"],
+        cpp_wrappers=[],
+        cuda_wrappers=[
+            ("rms_norm_gated", "RMSNormGatedKernel::run"),
+        ],
+        func_name="rms_norm_gated",
+    )
+    def _kernel(
+        compiled_module,
+        output: torch.Tensor,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        gate: torch.Tensor,
+        eps: float,
+    ) -> None:
+        compiled_module.rms_norm_gated(output, input, weight, gate, eps)
+
+    return _kernel
+
+
+def rms_norm_gated(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    z: torch.Tensor | None = None,
+    eps: float = 1e-6,
+) -> torch.Tensor:
+    """Fused RMSNorm with optional SiLU gating.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input tensor, shape ``(M, N)`` or ``(..., N)``.
+    weight : torch.Tensor
+        Normalization weight, shape ``(N,)``.
+    z : torch.Tensor or None
+        Optional gating tensor, same shape as ``x``.
+        If provided: ``output = rmsnorm(x) * silu(z)``
+    eps : float
+        Epsilon for numerical stability.
+
+    Returns
+    -------
+    torch.Tensor
+        Output with same shape and dtype as ``x``.
+    """
+    x_shape = x.shape
+    x_2d = x.reshape(-1, x.shape[-1])
+
+    # NOTE(review): checking only stride(-1) == 1 does not guarantee the
+    # dense [M, N] layout the CUDA kernel assumes (it indexes input + row*N).
+    # A 2-D slice like big[:, :N] reshapes to a view with unit inner stride
+    # but a padded row stride, and would be read incorrectly. Consider
+    # `if not x_2d.is_contiguous(): x_2d = x_2d.contiguous()` (same for z_2d)
+    # unless the C++ TensorMatcher is known to enforce contiguity.
+    if z is not None:
+        z_2d = z.reshape(-1, z.shape[-1])
+        if z_2d.stride(-1) != 1:
+            z_2d = z_2d.contiguous()
+    else:
+        z_2d = x.new_empty(0)  # empty tensor signals "no gate" to the kernel
+
+    if x_2d.stride(-1) != 1:
+        x_2d = x_2d.contiguous()
+
+    output = torch.empty_like(x_2d)
+    kernel = _make_rms_norm_gated_kernel()
+    kernel(output, x_2d, weight.contiguous(), z_2d, eps)
+    return output.reshape(x_shape)
diff --git a/mllm-kernel/mllm_kernel/cuda/jit/store_cache.py b/mllm-kernel/mllm_kernel/cuda/jit/store_cache.py
new file mode 100644
index 000000000..96a73f5ef
--- /dev/null
+++ b/mllm-kernel/mllm_kernel/cuda/jit/store_cache.py
@@ -0,0 +1,127 @@
+# Copyright (c) MLLM Team.
+# Licensed under the MIT License.
+#
+# Python interface for the store_cache CUDA kernel.
+# Efficiently scatters key/value tensors into a pre-allocated KV cache pool.
+
+from __future__ import annotations
+
+import logging
+import torch
+from mllm_kernel.jit_utils import jit
+from mllm_kernel.jit_utils.compile import cache_once, make_cpp_args
+
+
+logger = logging.getLogger(__name__)
+
+
+@cache_once
+def _is_arch_support_pdl() -> bool:
+    """Return True when the current GPU supports Programmatic Dependent Launch."""
+    if not torch.cuda.is_available():
+        return False
+    major, minor = torch.cuda.get_device_capability()
+    # PDL requires sm_90a (Hopper) or later
+    # NOTE(review): this expression reduces to ``major >= 9`` — the
+    # ``minor >= 0`` clause is always true and can be dropped.
+    return major > 9 or (major == 9 and minor >= 0)
+
+
+def _make_store_cache_kernel(row_bytes: int):
+    """Create a JIT-compiled store_cache kernel for the given row_bytes."""
+    # row_bytes and the PDL flag are baked into the C++ template
+    # instantiation, so each distinct row size compiles its own module.
+    pdl = _is_arch_support_pdl()
+    cpp_args = make_cpp_args(row_bytes, pdl)
+
+    @jit(
+        args=[row_bytes, pdl],
+        device="cuda",
+        cuda_files=["store_cache.cuh"],
+        cpp_wrappers=[],
+        cuda_wrappers=[
+            ("store_cache", f"StoreKVCacheKernel<{cpp_args}>::run"),
+        ],
+        func_name="store_cache",
+    )
+    def _kernel(
+        compiled_module,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        indices: torch.Tensor,
+        num_split: int,
+    ) -> None:
+        compiled_module.store_cache(k, v, k_cache, v_cache, indices, num_split)
+
+    return _kernel
+
+
+# Per-row-size cache of compiled kernels (one JIT module per row_bytes).
+_KERNEL_CACHE: dict[int, object] = {}
+
+
+def _get_kernel(row_bytes: int):
+    """Return the compiled kernel for row_bytes, compiling on first use."""
+    if row_bytes not in _KERNEL_CACHE:
+        _KERNEL_CACHE[row_bytes] = _make_store_cache_kernel(row_bytes)
+    return _KERNEL_CACHE[row_bytes]
+
+
+# NOTE(review): ``cache_once`` is used elsewhere in this file only on a
+# zero-argument function (_is_arch_support_pdl). If it caches a single
+# result and ignores arguments, the first call's answer would be returned
+# for every subsequent row_bytes — the unit test calls this with 2, 6 and 4
+# and expects different answers. Confirm cache_once keys on its arguments
+# (and note _get_kernel already memoizes per row_bytes regardless).
+@cache_once
+def can_use_store_cache(row_bytes: int) -> bool:
+    """Check whether the JIT store_cache kernel supports the given row size.
+
+    Returns ``False`` if *row_bytes* is not a multiple of 4 or if the JIT
+    compilation fails for any reason.
+    """
+    if row_bytes % 4 != 0:
+        logger.warning(
+            "Unsupported row_bytes=%d for JIT store_cache kernel: "
+            "must be multiple of 4",
+            row_bytes,
+        )
+        return False
+    try:
+        _get_kernel(row_bytes)
+        return True
+    except Exception as e:
+        logger.warning(
+            "Failed to load JIT store_cache kernel with row_bytes=%d: %s",
+            row_bytes,
+            e,
+        )
+        return False
+
+
+def store_cache(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    indices: torch.Tensor,
+    *,
+    row_bytes: int = 0,
+    num_split: int = 0,
+) -> None:
+    """Store key and value tensors into a KV cache at specified indices.
+
+    Each row of *k* (and *v*) is scattered into *k_cache* (and *v_cache*)
+    at the location given by the corresponding entry in *indices*.
+
+    Args:
+        k: Key tensor, shape ``(batch_size, head_num * head_dim)``.
+        v: Value tensor, shape ``(batch_size, head_num * head_dim)``.
+        k_cache: Key cache, shape ``(num_slots, head_num * head_dim)``.
+        v_cache: Value cache, shape ``(num_slots, head_num * head_dim)``.
+        indices: Index tensor, shape ``(batch_size,)``, dtype int32 or int64.
+        row_bytes: Bytes per row. Auto-detected from *k* when 0.
+        num_split: Number of warps that cooperate on each element (1, 2, or 4).
+            When 0 the best value is chosen automatically based on alignment.
+    """
+    row_bytes = row_bytes or k.shape[-1] * k.element_size()
+    kernel = _get_kernel(row_bytes)
+
+    # Heuristic split selection. The 2048/1024 thresholds are stricter than
+    # the alignment the C++ side requires for split 4/2 (512/256 bytes in
+    # store_cache.cuh::get_kernel), so the chosen split is always accepted.
+    if num_split <= 0:
+        if row_bytes % 2048 == 0:
+            num_split = 4
+        elif row_bytes % 1024 == 0:
+            num_split = 2
+        else:
+            num_split = 1
+
+    kernel(k, v, k_cache, v_cache, indices, num_split)
diff --git a/mllm-kernel/pyproject.toml b/mllm-kernel/pyproject.toml
index f64e1306e..13147f068 100644
--- a/mllm-kernel/pyproject.toml
+++ b/mllm-kernel/pyproject.toml
@@ -18,7 +18,7 @@ dependencies = [
     "packaging",
     "torch",
     "torch-c-dlpack-ext",
-    "apache-tvm-ffi",
+    "apache-tvm-ffi == 0.1.8.post2",
 ]
 
 [project.optional-dependencies]
 dev = [
     "pytest",
     "pytest-html",
 ]
 
+[project.scripts]
+mllm-kernel = "mllm_kernel.__main__:main"
+
 [tool.scikit-build]
 # Build configuration
 wheel.py-api = "py3"
@@ -52,7 +55,7 @@ logging.level = "INFO"
 
 # Wheel configuration - include the Python package
 wheel.packages = ["mllm_kernel"]
-wheel.install-dir = "mllm_kernel"
+wheel.install-dir = ""
 
 # Install directories for cmake targets
 wheel.cmake = true
diff --git a/mllm-kernel/tests/test_create_kv_indices.py b/mllm-kernel/tests/test_create_kv_indices.py
new file mode 100644
index 000000000..e8bf770a3
--- /dev/null
+++ b/mllm-kernel/tests/test_create_kv_indices.py
@@ -0,0 +1,191 @@
+from __future__ import annotations
+
+import pytest
+import torch
+
+from mllm_kernel.cuda.jit.create_kv_indices import create_kv_indices
+
+
+def _make_batch(
+    *,
+    max_reqs: int,
+    max_ctx: int,
+    batch_size: int,
+    use_start_offsets: bool,
+    seed: int = 0,
+):
+    """Construct a random-but-bounded test batch for create_kv_indices.
+
+    The constraints ensure that for every sequence i:
+        0 <= kv_start_idx[i]
+        0 <  page_kernel_lens[i]
+        kv_start_idx[i] + page_kernel_lens[i] <= max_ctx
+    so the kernel never reads beyond the ReqToTokenPool row.
+    """
+    # Use a CUDA generator for randperm (which requires matching device)
+    # and a separate CPU generator for randint (which only accepts CPU).
+    g_cuda = torch.Generator(device="cuda").manual_seed(seed)
+    g_cpu = torch.Generator(device="cpu").manual_seed(seed)
+
+    device = "cuda"
+    # req_to_token[req_slot, position] -> kv_index (here we simply use a
+    # monotonically increasing pattern so correctness is easy to check).
+    req_to_token = torch.arange(
+        max_reqs * max_ctx, dtype=torch.int32, device=device
+    ).reshape(max_reqs, max_ctx)
+
+    # Sample distinct request slots for the batch.
+    assert batch_size <= max_reqs
+    req_pool_indices = torch.randperm(max_reqs, generator=g_cuda, device=device)[
+        :batch_size
+    ].to(torch.int32)
+
+    # For each sequence choose a valid (start, length) pair.
+    page_kernel_lens_list = []
+    kv_start_idx_list = []
+    for _ in range(batch_size):
+        # ensure at least 1 token per sequence
+        L = int(torch.randint(1, max_ctx, (1,), generator=g_cpu).item())
+        if use_start_offsets:
+            # start is drawn from [0, max(start_max, 1)), so start + L <= max_ctx
+            # always holds (the max() guard covers the start_max == 0 case).
+            start_max = max_ctx - L
+            start = int(torch.randint(0, max(start_max, 1), (1,), generator=g_cpu).item())
+        else:
+            start = 0
+        page_kernel_lens_list.append(L)
+        kv_start_idx_list.append(start)
+
+    page_kernel_lens = torch.tensor(
+        page_kernel_lens_list, dtype=torch.int32, device=device
+    )
+    kv_start_idx = torch.tensor(kv_start_idx_list, dtype=torch.int32, device=device)
+
+    # Build kv_indptr prefix sums.
+    kv_indptr = torch.empty(batch_size + 1, dtype=torch.int32, device=device)
+    kv_indptr[0] = 0
+    kv_indptr[1:] = torch.cumsum(page_kernel_lens, dim=0)
+
+    # Output buffer sized exactly to the total token count.
+    kv_indices = torch.empty(
+        int(kv_indptr[-1].item()), dtype=torch.int32, device=device
+    )
+
+    return (
+        req_to_token,
+        req_pool_indices,
+        page_kernel_lens,
+        kv_indptr,
+        kv_start_idx,
+        kv_indices,
+    )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
+@pytest.mark.parametrize("use_start_offsets", [False, True])
+@pytest.mark.parametrize(
+    "batch_size,max_reqs,max_ctx",
+    [
+        (1, 4, 16),        # minimal batch
+        (4, 8, 64),        # small batch
+        (32, 64, 512),     # medium batch, longer context
+        (128, 256, 2048),  # larger batch, stress inner loop
+    ],
+)
+def test_create_kv_indices_matches_reference(
+    use_start_offsets: bool,
+    batch_size: int,
+    max_reqs: int,
+    max_ctx: int,
+):
+    """create_kv_indices must match a naive PyTorch reference implementation.
+
+    The reference is computed on CPU using explicit loops over
+    (request_slot, start, length); the CUDA kernel must produce identical
+    flat kv_indices for the same inputs.
+    """
+    (
+        req_to_token,
+        req_pool_indices,
+        page_kernel_lens,
+        kv_indptr,
+        kv_start_idx,
+        kv_indices,
+    ) = _make_batch(
+        max_reqs=max_reqs,
+        max_ctx=max_ctx,
+        batch_size=batch_size,
+        use_start_offsets=use_start_offsets,
+        seed=2026,
+    )
+
+    # Call CUDA kernel (kv_start_idx can be None to exercise that path).
+    create_kv_indices(
+        req_to_token,
+        req_pool_indices,
+        page_kernel_lens,
+        kv_indptr,
+        kv_start_idx if use_start_offsets else None,
+        kv_indices,
+    )
+    torch.cuda.synchronize()
+
+    # Naive reference on CPU.
+    req_to_token_cpu = req_to_token.cpu()
+    req_pool_indices_cpu = req_pool_indices.cpu().to(torch.long)
+    page_kernel_lens_cpu = page_kernel_lens.cpu()
+    kv_start_idx_cpu = kv_start_idx.cpu()
+
+    ref_segments = []
+    for i in range(batch_size):
+        req = req_pool_indices_cpu[i].item()
+        start = kv_start_idx_cpu[i].item() if use_start_offsets else 0
+        L = page_kernel_lens_cpu[i].item()
+        row = req_to_token_cpu[req, start : start + L]
+        ref_segments.append(row)
+    ref = torch.cat(ref_segments, dim=0)
+
+    assert kv_indices.shape == ref.shape
+    assert torch.equal(kv_indices.cpu(), ref)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
+def test_single_token_per_sequence():
+    """Each sequence has exactly 1 token — exercises the minimal-work path."""
+    device = "cuda"
+    bs = 8
+    max_ctx = 32
+    req_to_token = torch.arange(bs * max_ctx, dtype=torch.int32, device=device).reshape(bs, max_ctx)
+    req_pool_indices = torch.arange(bs, dtype=torch.int32, device=device)
+    page_kernel_lens = torch.ones(bs, dtype=torch.int32, device=device)
+    kv_indptr = torch.arange(bs + 1, dtype=torch.int32, device=device)
+    kv_indices = torch.empty(bs, dtype=torch.int32, device=device)
+
+    create_kv_indices(req_to_token, req_pool_indices, page_kernel_lens, kv_indptr, None, kv_indices)
+    torch.cuda.synchronize()
+
+    # Each sequence contributes req_to_token[i, 0].
+    expected = req_to_token[:, 0]
+    assert torch.equal(kv_indices, expected)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
+def test_oversized_output_buffer():
+    """kv_indices buffer is larger than needed (prefill path uses +256 padding)."""
+    device = "cuda"
+    bs = 4
+    max_ctx = 64
+    req_to_token = torch.arange(bs * max_ctx, dtype=torch.int32, device=device).reshape(bs, max_ctx)
+    req_pool_indices = torch.arange(bs, dtype=torch.int32, device=device)
+    page_kernel_lens = torch.full((bs,), 10, dtype=torch.int32, device=device)
+    kv_indptr = torch.arange(0, bs * 10 + 1, 10, dtype=torch.int32, device=device)
+    # Allocate with extra padding, like the prefill path does.
+    kv_indices = torch.full((bs * 10 + 256,), -1, dtype=torch.int32, device=device)
+
+    create_kv_indices(req_to_token, req_pool_indices, page_kernel_lens, kv_indptr, None, kv_indices)
+    torch.cuda.synchronize()
+
+    # First bs*10 entries should match; padding should remain -1.
+    ref_segments = []
+    for i in range(bs):
+        ref_segments.append(req_to_token[i, :10])
+    ref = torch.cat(ref_segments, dim=0)
+    assert torch.equal(kv_indices[:bs * 10], ref)
+    assert torch.all(kv_indices[bs * 10:] == -1)
diff --git a/mllm-kernel/tests/test_store_cache.py b/mllm-kernel/tests/test_store_cache.py
new file mode 100644
index 000000000..5e4f1bcc3
--- /dev/null
+++ b/mllm-kernel/tests/test_store_cache.py
@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+import pytest
+import torch
+
+from mllm_kernel.cuda.jit import can_use_store_cache, store_cache
+
+
+def _make_inputs(
+    *,
+    batch_size: int,
+    num_slots: int,
+    row_dim: int,
+    dtype: torch.dtype,
+    index_dtype: torch.dtype,
+    seed: int = 0,
+):
+    torch.manual_seed(seed)
+    device = "cuda"
+    k = torch.randn(batch_size, row_dim, device=device, dtype=dtype)
+    v = torch.randn(batch_size, row_dim, device=device, dtype=dtype)
+    # Use unique indices to avoid write conflicts on the same cache slot.
+ indices = torch.randperm(num_slots, device=device)[:batch_size].to(index_dtype) + k_cache = torch.zeros(num_slots, row_dim, device=device, dtype=dtype) + v_cache = torch.zeros_like(k_cache) + return k, v, k_cache, v_cache, indices + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64]) +def test_store_cache_matches_torch_index(dtype: torch.dtype, index_dtype: torch.dtype): + batch_size = 257 + num_slots = 4096 + row_dim = 8 * 128 # 1024 -> fp16 row_bytes=2048 + row_bytes = row_dim * torch.tensor([], dtype=dtype).element_size() + + assert can_use_store_cache(row_bytes), f"store_cache unavailable for row_bytes={row_bytes}" + + k, v, k_cache, v_cache, indices = _make_inputs( + batch_size=batch_size, + num_slots=num_slots, + row_dim=row_dim, + dtype=dtype, + index_dtype=index_dtype, + seed=2026, + ) + + k_ref = k_cache.clone() + v_ref = v_cache.clone() + k_ref[indices] = k + v_ref[indices] = v + + store_cache(k, v, k_cache, v_cache, indices) + torch.cuda.synchronize() + + assert torch.equal(k_cache, k_ref) + assert torch.equal(v_cache, v_ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_can_use_store_cache_rejects_invalid_row_bytes(): + assert not can_use_store_cache(2) + assert not can_use_store_cache(6) + assert can_use_store_cache(4) + diff --git a/mllm/ffi/Extension.cc b/mllm/ffi/Extension.cc index cb999191d..f3f2d2488 100644 --- a/mllm/ffi/Extension.cc +++ b/mllm/ffi/Extension.cc @@ -83,12 +83,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Tensor related refl::GlobalDef().def("mllm.empty", mllm::ffi::empty); refl::GlobalDef().def("mllm.from_torch", [](const tvm::ffi::Tensor& t) -> mllm::ffi::Tensor { - auto dl_pack = t.get()->ToDLPack(); + auto dl_pack = t.ToDLPack(); return ::mllm::ffi::Tensor(mllm::ffi::__from_dlpack(dl_pack)); }); 
refl::GlobalDef().def("mllm.from_numpy", [](const tvm::ffi::Tensor& t) -> mllm::ffi::Tensor { - auto dl_pack = t.get()->ToDLPack(); + auto dl_pack = t.ToDLPack(); return ::mllm::ffi::Tensor(mllm::ffi::__from_dlpack(dl_pack)); }); @@ -345,6 +345,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::ObjectDef<::mllm::ffi::BaseOpObj>(); + refl::ObjectDef<::mllm::ffi::ParameterFileObj>(); refl::GlobalDef().def("mllm.BaseOp.load", [](const mllm::ffi::BaseOp& self, const mllm::ffi::ParameterFile& obj) -> void { self.get()->op_ptr_->load(obj.get()->pf_ptr_); }); diff --git a/mllm/ffi/vendors/tvm-ffi b/mllm/ffi/vendors/tvm-ffi index 46f735807..dcd07cfe2 160000 --- a/mllm/ffi/vendors/tvm-ffi +++ b/mllm/ffi/vendors/tvm-ffi @@ -1 +1 @@ -Subproject commit 46f73580780f2973e6ea3afb6d3a9d6f6ffd02cc +Subproject commit dcd07cfe27465287ee5b203b742e85dcfb99606a diff --git a/pymllm/README.md b/pymllm/README.md index e69de29bb..bee5ac41c 100644 --- a/pymllm/README.md +++ b/pymllm/README.md @@ -0,0 +1,3 @@ +# pymllm + +![pymllm-arch](../assets/pymllm-arch.png) diff --git a/pymllm/__init__.py b/pymllm/__init__.py index 1bd31cd6c..3f2488d27 100644 --- a/pymllm/__init__.py +++ b/pymllm/__init__.py @@ -2,48 +2,32 @@ # Licensed under the MIT License. from __future__ import annotations +import os +import sys -from . import ffi -from . import convertor -from . import utils -from . import quantize -from . import nn -from . import compile -from . import service -from . 
import backends -from .ffi import ( - # Floating point types - float32, - float16, - bfloat16, - # Signed integer types - int8, - int16, - int32, - int64, - # Unsigned integer types - uint8, - uint16, - uint32, - uint64, - # Bool type - boolean, - # Devices - cpu, - cuda, - qnn, - # Tensor and utilities - Tensor, - empty, - echo, - device, - is_torch_available, - is_numpy_available, - from_torch, - from_numpy, - zeros, - ones, - arange, - random, -) -from .nn.functional import matmul +__all__ = [] + + +def _has_mobile_libs() -> bool: + parent_dir = os.path.dirname(os.path.realpath(__file__)) + + # Platform-specific library names + if sys.platform.startswith("win32"): + lib_name = "MllmFFIExtension.dll" + elif sys.platform.startswith("darwin"): + lib_name = "MllmFFIExtension.dylib" + else: + lib_name = "MllmFFIExtension.so" + + lib_path = os.path.join(parent_dir, "lib", lib_name) + return os.path.exists(lib_path) + + +def is_mobile_available() -> bool: + return _has_mobile_libs() + + +if _has_mobile_libs(): + from . import mobile + + __all__.append("mobile") diff --git a/pymllm/__main__.py b/pymllm/__main__.py new file mode 100644 index 000000000..0b427fcee --- /dev/null +++ b/pymllm/__main__.py @@ -0,0 +1,39 @@ +def show_config() -> None: + from . import is_mobile_available + + mobile_enabled = str(is_mobile_available()).lower() + print(f"mllm mobile: {mobile_enabled}") + + # try import mllm_kernel, if true, print mllm_kernel config + try: + import mllm_kernel + + print(f"mllm_kernel: {mllm_kernel.__version__}") + except ImportError: + print("mllm_kernel: not found") + + +def main() -> None: + import argparse + + parser = argparse.ArgumentParser( + prog="pymllm", + description="pymllm helper commands.", + ) + parser.add_argument( + "command", + nargs="?", + choices=["show-config"], + help="Run helper command. 
Use 'show-config' to print config details.", + ) + args = parser.parse_args() + + if args.command == "show-config": + show_config() + return + + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/pymllm/backends/__init__.py b/pymllm/backends/__init__.py deleted file mode 100644 index 5e926d580..000000000 --- a/pymllm/backends/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) MLLM Team. -# Licensed under the MIT License. - -from . import cuda, qualcomm diff --git a/pymllm/backends/cuda/tilelang_compile_test.py b/pymllm/backends/cuda/tilelang_compile_test.py deleted file mode 100644 index 65a2e0071..000000000 --- a/pymllm/backends/cuda/tilelang_compile_test.py +++ /dev/null @@ -1,41 +0,0 @@ -import tilelang -import tilelang.language as T - - -@tilelang.jit( - out_idx=[-1], compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"] -) -def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): - @T.prim_func - def elem_add( - A: T.Tensor((M, N), in_dtype), - B: T.Tensor((M, N), in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads - ) as (bx, by): - A_shared = T.alloc_shared((block_M, block_N), in_dtype) - B_shared = T.alloc_shared((block_M, block_N), in_dtype) - C_local = T.alloc_fragment((block_M, block_N), out_dtype) - C_shared = T.alloc_shared((block_M, block_N), out_dtype) - - T.copy(A[by * block_M, bx * block_N], A_shared) - T.copy(B[by * block_M, bx * block_N], B_shared) - for local_y, local_x in T.Parallel(block_M, block_N): - C_local[local_y, local_x] = ( - A_shared[local_y, local_x] + B_shared[local_y, local_x] - ) - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return elem_add - - -def compile_test(): - M = 1024 - N = 1024 - config = {"block_M": 128, "block_N": 128, "threads": 128} - kernel = elementwise_add(M, N, **config, in_dtype="float16", out_dtype="float16") - source = 
kernel.get_kernel_source() - print(source) diff --git a/pymllm/configs/__init__.py b/pymllm/configs/__init__.py new file mode 100644 index 000000000..a23de035c --- /dev/null +++ b/pymllm/configs/__init__.py @@ -0,0 +1,14 @@ +"""Configuration module for pymllm.""" + +from pymllm.configs.global_config import GlobalConfig, get_global_config +from pymllm.configs.model_config import ModelConfig +from pymllm.configs.quantization_config import QuantizationConfig +from pymllm.configs.server_config import ServerConfig + +__all__ = [ + "GlobalConfig", + "get_global_config", + "ServerConfig", + "ModelConfig", + "QuantizationConfig", +] diff --git a/pymllm/configs/global_config.py b/pymllm/configs/global_config.py new file mode 100644 index 000000000..711de3cd1 --- /dev/null +++ b/pymllm/configs/global_config.py @@ -0,0 +1,342 @@ +"""Global configuration singleton aggregating all sub-configs.""" + +from __future__ import annotations + +import argparse +import types +from dataclasses import MISSING, dataclass, field, fields +from pathlib import Path +from typing import ( + Any, + Callable, + Literal, + Optional, + Sequence, + Union, + get_args, + get_origin, + get_type_hints, +) + +from pymllm.configs.server_config import ServerConfig +from pymllm.configs.model_config import ModelConfig +from pymllm.configs.quantization_config import QuantizationConfig + + +@dataclass +class GlobalConfig: + """Singleton that holds every sub-config pymllm needs. 
+ + Usage:: + + from pymllm.configs import get_global_config + + cfg = get_global_config() + cfg.model.model_path + cfg.model.hidden_size + cfg.quantization.method + cfg.server.host + """ + + server: "ServerConfig" = field(default=None, repr=False) # type: ignore[assignment] + model: ModelConfig = field(default_factory=ModelConfig) + quantization: QuantizationConfig = field(default_factory=QuantizationConfig) + + _initialized: bool = field(default=False, repr=False) + + def __new__(cls): + if not hasattr(cls, "_instance") or cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __post_init__(self): + if self.server is None: + self.server = ServerConfig(model_path=None) + + @classmethod + def get_instance(cls) -> "GlobalConfig": + if not hasattr(cls, "_instance") or cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def reset(cls) -> None: + """Destroy the singleton (useful in tests).""" + cls._instance = None + + +def _parse_bool(value: Any) -> bool: + """Convert common CLI boolean spellings into ``bool``. + + This helper is intentionally permissive because CLI users often provide + booleans in different forms (for example ``true``, ``1``, ``yes``, + ``false``, ``0``, ``no``). The function raises ``argparse.ArgumentTypeError`` + to integrate naturally with ``argparse`` validation and error reporting. + """ + + if isinstance(value, bool): + return value + if value is None: + return True + + lowered = str(value).strip().lower() + if lowered in {"1", "true", "t", "yes", "y", "on"}: + return True + if lowered in {"0", "false", "f", "no", "n", "off"}: + return False + raise argparse.ArgumentTypeError( + f"Invalid boolean value: {value!r}. Expected one of true/false, 1/0, yes/no." 
+ ) + + +def _unwrap_optional(annotation: Any) -> tuple[Any, bool]: + """Return ``(inner_type, is_optional)`` for Optional/Union annotations.""" + + origin = get_origin(annotation) + if origin not in (Union, types.UnionType): + return annotation, False + + args = [arg for arg in get_args(annotation) if arg is not type(None)] + if len(args) == 1 and len(get_args(annotation)) == 2: + return args[0], True + return annotation, False + + +def _converter_for_annotation(annotation: Any) -> Optional[Callable[[str], Any]]: + """Map a type annotation to an ``argparse`` converter. + + Only scalar, CLI-friendly annotations are supported. Complex runtime fields + (for example nested dict/object handles) are intentionally excluded from the + generated CLI surface to keep the interface predictable and safe. + """ + + inner, _ = _unwrap_optional(annotation) + origin = get_origin(inner) + if origin is not None: + if origin is Literal: + literal_values = get_args(inner) + if literal_values: + return type(literal_values[0]) + return str + return None + + if inner in (str, int, float): + return inner + if inner is Path: + return Path + return None + + +def _choices_for_annotation(annotation: Any) -> Optional[list]: + """Extract allowed values from a ``Literal`` annotation, if applicable.""" + + inner, _ = _unwrap_optional(annotation) + origin = get_origin(inner) + if origin is Literal: + return list(get_args(inner)) + return None + + +def _is_bool_annotation(annotation: Any) -> bool: + """Return ``True`` if annotation represents a bool/Optional[bool] field.""" + + inner, _ = _unwrap_optional(annotation) + return inner is bool + + +def _format_default_for_help(value: Any) -> str: + """Create a concise, readable default string for CLI help text.""" + + if value is MISSING: + return "" + if value is None: + return "None" + if isinstance(value, Path): + return str(value) + return repr(value) + + +def make_args( + parser: Optional[argparse.ArgumentParser] = None, +) -> 
argparse.ArgumentParser:
+    """Create an ``argparse`` parser with two-level GlobalConfig CLI options.
+
+    The generated options follow the naming pattern ``--<section>.<field>`` so
+    each sub-config can be configured independently:
+
+    - ``server`` options map to :class:`ServerConfig` fields.
+    - ``model`` options map to :class:`ModelConfig` fields.
+    - ``quantization`` options map to :class:`QuantizationConfig` fields.
+
+    Examples
+    --------
+    - ``--server.host 0.0.0.0``
+    - ``--server.port 8080``
+    - ``--server.sleep_on_idle`` (implicit true)
+    - ``--server.sleep_on_idle false`` (explicit false)
+    - ``--quantization.method awq``
+
+    Design notes
+    ------------
+    - Options are generated from dataclass metadata, which keeps the CLI surface
+      synchronized with config definitions and avoids manual drift.
+    - Parser defaults are suppressed (``argparse.SUPPRESS``), so ``read_args``
+      can reliably detect whether a value was explicitly provided by the user.
+    - Only CLI-friendly scalar fields are exposed; runtime-only fields are
+      skipped automatically.
+    """
+
+    if parser is None:
+        parser = argparse.ArgumentParser(
+            prog="pymllm",
+            description="CLI options for configuring pymllm GlobalConfig.",
+        )
+
+    cfg = GlobalConfig.get_instance()
+    sections: list[tuple[str, Any]] = [
+        ("server", cfg.server),
+        ("model", cfg.model),
+        ("quantization", cfg.quantization),
+    ]
+
+    for section_name, section_obj in sections:
+        section_group = parser.add_argument_group(
+            f"{section_name} config",
+            f"Options for the '{section_name}' section of GlobalConfig.",
+        )
+        type_hints = get_type_hints(type(section_obj))
+        for dc_field in fields(section_obj):
+            if dc_field.name.startswith("_"):
+                continue
+
+            annotation = type_hints.get(dc_field.name, dc_field.type)
+            option = f"--{section_name}.{dc_field.name}"
+            dest = f"{section_name}__{dc_field.name}"
+            default_value = getattr(section_obj, dc_field.name)
+
+            if _is_bool_annotation(annotation):
+                section_group.add_argument(
+                    option,
+                    dest=dest,
+                    nargs="?",
+                    const=True,
+                    type=_parse_bool,
+                    default=argparse.SUPPRESS,
+                    help=(
+                        f"{section_name}.{dc_field.name} (bool, default: "
+                        
f"{_format_default_for_help(default_value)}). " + "Can be provided as a flag for true or with an explicit value." + ), + ) + continue + + converter = _converter_for_annotation(annotation) + if converter is None: + # Skip non-scalar or runtime-only fields (e.g. arbitrary objects). + continue + + choices = _choices_for_annotation(annotation) + kwargs: dict[str, Any] = dict( + dest=dest, + type=converter, + default=argparse.SUPPRESS, + ) + if choices is not None: + kwargs["choices"] = choices + choices_str = ", ".join(str(c) for c in choices) + kwargs["help"] = ( + f"{section_name}.{dc_field.name} " + f"{{choices: {choices_str}}} " + f"(default: {_format_default_for_help(default_value)})." + ) + else: + kwargs["help"] = ( + f"{section_name}.{dc_field.name} (default: " + f"{_format_default_for_help(default_value)})." + ) + + section_group.add_argument(option, **kwargs) + + return parser + + +def read_args( + argv: Optional[Sequence[str]] = None, + parser: Optional[argparse.ArgumentParser] = None, +) -> GlobalConfig: + """Parse CLI args and apply overrides to the singleton ``GlobalConfig``. + + Parameters + ---------- + argv + Optional argument vector. If ``None``, ``argparse`` reads from + ``sys.argv`` (standard CLI behavior). + parser + Optional parser to use. When omitted, this function builds one through + :func:`make_args`. + + Returns + ------- + GlobalConfig + The singleton config instance after CLI overrides have been applied. + + Behavior + -------- + 1. Parse all generated ``--section.field`` options. + 2. Apply only explicitly provided options (no accidental overwrite by parser + defaults). + 3. Rebuild ``ServerConfig`` when server fields change so validation in + ``ServerConfig.__post_init__`` and ``_validate`` remains enforced. + 4. Keep ``server.model_path`` and ``model.model_path`` aligned when only one + side is explicitly overridden (the same precedence used by runtime config + loading conventions). 
+ """ + + if parser is None: + parser = make_args() + + namespace = parser.parse_args(argv) + parsed = vars(namespace) + cfg = GlobalConfig.get_instance() + + # Server: reconstruct to preserve validation behavior. + from pymllm.configs.server_config import ServerConfig + + server_updates: dict[str, Any] = {} + for dc_field in fields(cfg.server): + key = f"server__{dc_field.name}" + if key in parsed: + server_updates[dc_field.name] = parsed[key] + if server_updates: + server_values = { + dc_field.name: getattr(cfg.server, dc_field.name) + for dc_field in fields(cfg.server) + } + server_values.update(server_updates) + cfg.server = ServerConfig(**server_values) + + # Model / Quantization: in-place updates are sufficient. + for section_name, section_obj in ( + ("model", cfg.model), + ("quantization", cfg.quantization), + ): + for dc_field in fields(section_obj): + key = f"{section_name}__{dc_field.name}" + if key in parsed: + setattr(section_obj, dc_field.name, parsed[key]) + + # Keep model path synchronized when only one side is explicitly overridden. 
+ server_model_overridden = "server__model_path" in parsed + model_model_overridden = "model__model_path" in parsed + if server_model_overridden and not model_model_overridden: + cfg.model.model_path = cfg.server.model_path + elif model_model_overridden and not server_model_overridden: + cfg.server.model_path = cfg.model.model_path + + cfg._initialized = True + return cfg + + +def get_global_config() -> GlobalConfig: + """Return the global config singleton.""" + return GlobalConfig.get_instance() diff --git a/pymllm/configs/model_config.py b/pymllm/configs/model_config.py new file mode 100644 index 000000000..c23dff1d9 --- /dev/null +++ b/pymllm/configs/model_config.py @@ -0,0 +1,31 @@ +"""Lightweight model configuration: path + HuggingFace config handle.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + + +@dataclass +class ModelConfig: + """Minimal model config wrapping a HuggingFace PretrainedConfig. + + Attributes on ``hf_config`` are flattened onto this object:: + + cfg = get_global_config().model + cfg.hidden_size # -> hf_config.hidden_size + cfg.vocab_size # -> hf_config.vocab_size + cfg.text_config # -> hf_config.text_config (multimodal) + """ + + # Populated at runtime via ``transformers.AutoConfig.from_pretrained`` + hf_config: Optional[Any] = field(default=None, repr=False) + + def __getattr__(self, name: str) -> Any: + hf = object.__getattribute__(self, "hf_config") + if hf is not None and hasattr(hf, name): + return getattr(hf, name) + raise AttributeError( + f"'{type(self).__name__}' has no attribute '{name}' " + f"(also not found on hf_config)" + ) diff --git a/pymllm/configs/quantization_config.py b/pymllm/configs/quantization_config.py new file mode 100644 index 000000000..850ea82b8 --- /dev/null +++ b/pymllm/configs/quantization_config.py @@ -0,0 +1,18 @@ +"""Quantization settings for model weights and KV cache.""" + +from __future__ import annotations + +from dataclasses 
import dataclass +from typing import Literal, Optional + + +@dataclass +class QuantizationConfig: + """Quantization configuration for weights and KV cache.""" + + # Weight quantization method (e.g. "awq", "gptq", "fp8", None for no quant) + method: Optional[str] = None + # KV cache data type override + kv_cache_dtype: Literal[ + "auto", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2" + ] = "auto" diff --git a/pymllm/configs/server_config.py b/pymllm/configs/server_config.py new file mode 100644 index 000000000..8727f7c13 --- /dev/null +++ b/pymllm/configs/server_config.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Literal, Optional +from dataclasses import dataclass, field + + +@dataclass +class ServerConfig: + """Centralized runtime configuration for the MLLM server.""" + + # --------------------------------------------------------------------- # + # Model and tokenizer configuration + # --------------------------------------------------------------------- # + model_path: Optional[Path] = None + tokenizer_path: Optional[Path] = None + tokenizer_mode: Literal["auto", "slow", "fast"] = "auto" + load_format: Literal["auto", "safetensors"] = "auto" + trust_remote_code: bool = False + download_dir: Optional[Path] = None + context_length: Optional[int] = None + dtype: Literal["auto", "float16", "bfloat16", "float32"] = "auto" + + # --------------------------------------------------------------------- # + # HTTP / API server + # --------------------------------------------------------------------- # + host: str = "127.0.0.1" + port: int = 30000 + fastapi_root_path: str = "" + api_key: Optional[str] = None + admin_api_key: Optional[str] = None + served_model_name: Optional[str] = None + file_storage_path: Path = Path("mllm_storage") + + # --------------------------------------------------------------------- # + # Scheduling and memory + # --------------------------------------------------------------------- # + 
mem_fraction_static: Optional[float] = None + max_running_requests: Optional[int] = 1 + max_queued_requests: Optional[int] = None + max_total_tokens: Optional[int] = None + chunked_prefill_size: Optional[int] = None + max_prefill_tokens: Optional[int] = None + schedule_policy: Literal["auto", "fcfs"] = "fcfs" + schedule_conservativeness: float = 1.0 + sleep_on_idle: bool = False + stream_interval: int = 1 + stream_output: bool = True + + # --------------------------------------------------------------------- # + # Device + # --------------------------------------------------------------------- # + base_gpu_id: int = 0 + + # --------------------------------------------------------------------- # + # Backend / acceleration + # --------------------------------------------------------------------- # + attention_backend: Literal["auto", "flashinfer"] = "auto" + gdn_decode_backend: Literal["auto", "flashinfer", "mllm_kernel", "pytorch"] = "auto" + sampling_backend: Optional[str] = None + disable_cuda_graph: bool = False + enable_torch_compile: bool = False + torch_compile_max_bs: int = 32 + random_seed: Optional[int] = 42 + + # --------------------------------------------------------------------- # + # Output parsers (reasoning / tool calls) + # --------------------------------------------------------------------- # + reasoning_parser: Optional[str] = None # e.g. "deepseek-r1", "qwen3" + tool_call_parser: Optional[str] = None # e.g. 
"qwen25", "llama3", "hermes" + + # --------------------------------------------------------------------- # + # Logging and observability + # --------------------------------------------------------------------- # + log_level: Literal["debug", "info", "warning", "error", "critical"] = "info" + enable_metrics: bool = False + show_time_cost: bool = False + # Log prefill/decode throughput stats every N decode batches (0 = disabled) + decode_log_interval: int = 40 + + # --------------------------------------------------------------------- # + # Feature switches + # --------------------------------------------------------------------- # + enable_shared_queue: bool = False # Use shared memory queue for fast IPC + disable_radix_cache: bool = False # Disable radix-tree prefix caching + radix_cache_page_size: int = 1 # Number of tokens per KV-pool page in RadixCache + + # CUDA IPC transport for multimodal GPU tensors. + # Requires enable_shared_queue=True to take effect. + # + # Three transport modes (mutually exclusive for GPU tensors): + # + # "default" + # GPU tensors are moved to CPU first (GPU→CPU copy), then placed in + # POSIX shared memory via share_memory_(). Safe but adds a device copy. + # + # "cuda_ipc" + # GPU tensors stay on GPU. Each tensor is wrapped in a + # TransportProxyTensor whose __getstate__ calls storage._share_cuda_() + # to obtain an IPC handle; the receiver reconstructs via + # UntypedStorage._new_shared_cuda(*handle). Simple, but the underlying + # GPU allocation is never freed until the sender process exits + # (PyTorch limitation) -- can leak GPU memory in long-running services. + # + # "cuda_ipc_pool" [recommended for production] + # GPU tensors are copied into a pre-allocated fixed-size GPU workspace + # (MmItemMemoryPool). Each outgoing tensor occupies a "chunk" of the + # pool; the chunk's IPC handle is sent via CudaIpcTensorTransportProxy. 
+ # After the receiver finishes copying data it increments a shared-memory + # sync flag; a background recycler thread in the sender watches these + # flags and returns chunks to the available pool. No GPU memory is leaked. + tensor_transport_mode: str = "default" # one of: default, cuda_ipc, cuda_ipc_pool + + # Size of the pre-allocated CUDA IPC memory pool in MB. + # Only used when tensor_transport_mode == "cuda_ipc_pool". + cuda_ipc_pool_size_mb: int = 512 + + # How often (seconds) the pool recycler thread wakes up. + cuda_ipc_recycle_interval: float = 0.1 + # enable_lora: bool = False + # max_loaded_loras: Optional[int] = None + # max_loras_per_batch: int = 8 + # lora_backend: Literal["triton", "csgmv", "torch_native"] = "csgmv" + # enable_multimodal: bool = False + # speculative_algorithm: Optional[str] = None + # speculative_draft_model_path: Optional[Path] = None + # speculative_num_steps: Optional[int] = None + # speculative_num_draft_tokens: Optional[int] = None + + # --------------------------------------------------------------------- # + # Extra + # --------------------------------------------------------------------- # + extra_options: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + if self.tokenizer_path is None: + self.tokenizer_path = self.model_path + if self.served_model_name is None: + self.served_model_name = str(self.model_path) + self._validate() + + def _validate(self) -> None: + valid_modes = {"default", "cuda_ipc", "cuda_ipc_pool"} + if self.tensor_transport_mode not in valid_modes: + raise ValueError( + f"`tensor_transport_mode` must be one of {valid_modes}, " + f"got {self.tensor_transport_mode!r}." + ) + if self.tensor_transport_mode != "default" and not self.enable_shared_queue: + raise ValueError( + "`tensor_transport_mode` != 'default' requires `enable_shared_queue=True`." 
+ ) + if self.cuda_ipc_pool_size_mb <= 0: + raise ValueError("`cuda_ipc_pool_size_mb` must be > 0.") + if self.port <= 0 or self.port > 65535: + raise ValueError("`port` must be in range [1, 65535].") + if self.max_prefill_tokens is not None and self.max_prefill_tokens <= 0: + raise ValueError("`max_prefill_tokens` must be > 0.") + if self.stream_interval <= 0: + raise ValueError("`stream_interval` must be > 0.") + if self.mem_fraction_static is not None and not ( + 0.0 < self.mem_fraction_static < 1.0 + ): + raise ValueError("`mem_fraction_static` must be in (0.0, 1.0).") + if self.max_running_requests is not None and self.max_running_requests <= 0: + raise ValueError("`max_running_requests` must be > 0 when set.") + if self.max_queued_requests is not None and self.max_queued_requests < 0: + raise ValueError("`max_queued_requests` must be >= 0 when set.") + if self.radix_cache_page_size < 1: + raise ValueError("`radix_cache_page_size` must be >= 1.") + if self.schedule_conservativeness <= 0: + raise ValueError("`schedule_conservativeness` must be > 0.") diff --git a/pymllm/engine/__init__.py b/pymllm/engine/__init__.py new file mode 100644 index 000000000..50f2b7249 --- /dev/null +++ b/pymllm/engine/__init__.py @@ -0,0 +1,8 @@ +"""Engine module for pymllm.""" + +from pymllm.engine.forward_batch import ForwardBatch, ForwardMode + +__all__ = [ + "ForwardBatch", + "ForwardMode", +] diff --git a/pymllm/engine/forward_batch.py b/pymllm/engine/forward_batch.py new file mode 100644 index 000000000..428da7b66 --- /dev/null +++ b/pymllm/engine/forward_batch.py @@ -0,0 +1,191 @@ +"""ForwardMode and ForwardBatch for pymllm. + +Simplified forward-batch abstraction: no speculative decoding, no +encoder-decoder support, and no distributed-attention complexity (DP/TP +head splitting is handled at the layer level by the model code, not here). 
+ +Typical data flow +----------------- + ModelRunner builds a ForwardBatch + ↓ + attn_backend.init_forward_metadata(forward_batch) + ↓ + model.forward(input_ids, positions, forward_batch) + ↓ + RadixAttention.forward(q, k, v, forward_batch) + ↓ + forward_batch.attn_backend.forward(q, k, v, layer, forward_batch) +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import IntEnum, auto +from typing import TYPE_CHECKING, List, Optional + +import torch + +if TYPE_CHECKING: + from pymllm.layers.attention.attention_backend import AttentionBackend + from pymllm.mem_cache.memory_pool import KVPool, ReqToTokenPool + + +# --------------------------------------------------------------------------- +# ForwardMode +# --------------------------------------------------------------------------- + + +class ForwardMode(IntEnum): + """Describes what kind of forward pass is being performed. + + Covers standard prefill / decode inference without speculative decoding. + """ + + # Prefill / extend: process new tokens. The KV cache of the prefix (if + # any) is already populated (e.g. shared system-prompt via radix cache). + EXTEND = auto() + + # Decode: generate exactly one new token per sequence. + DECODE = auto() + + # Mixed: a chunked-prefill batch that contains both extend and decode + # sequences simultaneously. + MIXED = auto() + + # Idle: no sequences to process (used with data-parallel workers when some + # ranks have no allocated sequences). + IDLE = auto() + + # ---- helpers ---- + + def is_extend(self) -> bool: + """True for EXTEND or MIXED (i.e. 
any prefill-style pass).""" + return self in (ForwardMode.EXTEND, ForwardMode.MIXED) + + def is_prefill(self) -> bool: + """Alias for ``is_extend()``.""" + return self.is_extend() + + def is_decode(self) -> bool: + return self == ForwardMode.DECODE + + def is_mixed(self) -> bool: + return self == ForwardMode.MIXED + + def is_idle(self) -> bool: + return self == ForwardMode.IDLE + + def is_decode_or_idle(self) -> bool: + return self == ForwardMode.DECODE or self == ForwardMode.IDLE + + +# --------------------------------------------------------------------------- +# ForwardBatch +# --------------------------------------------------------------------------- + + +@dataclass +class ForwardBatch: + """All tensors required by a single forward pass through the model. + + Parameters + ---------- + forward_mode + The kind of pass being performed (EXTEND / DECODE / MIXED / IDLE). + batch_size + Number of sequences in the batch. + input_ids + Token ids for every position in the batch, shape ``[num_tokens]``. + For decode, ``num_tokens == batch_size``; for extend, + ``num_tokens == extend_num_tokens``. + req_pool_indices + Index of each sequence in ``ReqToTokenPool``, shape ``[batch_size]`` + (int32 or int64, on the target device). + seq_lens + Total (prefix + new) length of each sequence, shape ``[batch_size]`` + (int32). + out_cache_loc + KV-pool slot that each *output* token is written to, shape + ``[num_tokens]`` (int64). + seq_lens_sum + Python ``int`` equal to ``seq_lens.sum()``. Cached to avoid repeated + device-to-host syncs. + seq_lens_cpu + CPU copy of ``seq_lens`` (optional; used by some attention backends + for plan computation without a device sync). + positions + Token position for each input token, shape ``[num_tokens]`` + (int32 or int64). + extend_num_tokens + Total number of new (non-prefix) tokens across the batch. Only set + during EXTEND / MIXED passes. + extend_seq_lens + Number of *new* tokens for each sequence, shape ``[batch_size]`` + (int32). 
Only set during EXTEND / MIXED. + extend_prefix_lens + Length of the already-cached prefix for each sequence, + shape ``[batch_size]`` (int32). Only set during EXTEND / MIXED. + extend_start_loc + Cumulative start offset of each sequence in the flattened extend + token stream, shape ``[batch_size]`` (int32). + extend_prefix_lens_cpu + CPU list mirror of ``extend_prefix_lens``. + extend_seq_lens_cpu + CPU list mirror of ``extend_seq_lens``. + return_logprob + Whether to compute per-token log-probabilities. + top_logprobs_nums + Number of top log-probs to return per sequence (None or list of ints). + req_to_token_pool + Reference to the ``ReqToTokenPool`` (set by the model runner). + token_to_kv_pool + Reference to the ``KVPool`` (set by the model runner). + attn_backend + The attention backend to use (set by the model runner before calling + ``model.forward``). + """ + + # ---- required fields (positional) ---- + forward_mode: ForwardMode + batch_size: int + input_ids: torch.Tensor # [num_tokens] + req_pool_indices: torch.Tensor # [batch_size] int32/int64 + seq_lens: torch.Tensor # [batch_size] int32 + out_cache_loc: torch.Tensor # [num_tokens] int64 + seq_lens_sum: int # python int + + # ---- optional metadata ---- + + # CPU mirror of seq_lens + seq_lens_cpu: Optional[torch.Tensor] = None + + # Position encoding – shape [num_tokens], int32 or int64 + positions: Optional[torch.Tensor] = None + + # ---- extend / prefill specific ---- + extend_num_tokens: Optional[int] = None + extend_seq_lens: Optional[torch.Tensor] = None # [batch_size] int32 + extend_prefix_lens: Optional[torch.Tensor] = None # [batch_size] int32 + extend_start_loc: Optional[torch.Tensor] = None # [batch_size] int32 + extend_prefix_lens_cpu: Optional[List[int]] = None + extend_seq_lens_cpu: Optional[List[int]] = None + + # ---- logprob options ---- + return_logprob: bool = False + top_logprobs_nums: Optional[List[int]] = None + + # ---- memory pools (set by model runner) ---- + req_to_token_pool: 
Optional["ReqToTokenPool"] = None + token_to_kv_pool: Optional["KVPool"] = None + + # ---- attention backend (set by model runner) ---- + attn_backend: Optional["AttentionBackend"] = None + + # ---- multimodal M-RoPE ---- + # Per-request position delta for M-RoPE decode steps. + # Set by the model during prefill; consumed during decode to offset positions. + mrope_position_deltas: Optional[torch.Tensor] = None # [batch_size] int64 + + # ---- multimodal vision inputs (extend / prefill only) ---- + pixel_values: Optional[torch.Tensor] = None + image_grid_thw: Optional[torch.Tensor] = None diff --git a/pymllm/engine/io_struct.py b/pymllm/engine/io_struct.py new file mode 100644 index 000000000..06c8d78d6 --- /dev/null +++ b/pymllm/engine/io_struct.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from typing import Any, Dict, Iterator, List, Optional, Union + + +@dataclass +class BaseReq: + rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True) + + def regenerate_rid(self) -> Union[str, List[str]]: + if isinstance(self.rid, list): + self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))] + else: + self.rid = uuid.uuid4().hex + return self.rid + + +@dataclass +class BaseBatchReq: + rids: List[str] + + def regenerate_rids(self) -> List[str]: + self.rids = [uuid.uuid4().hex for _ in range(len(self.rids))] + return self.rids + + +@dataclass +class GenerateReqInput(BaseReq): + text: Optional[Union[List[str], str]] = None + input_ids: Optional[Union[List[List[int]], List[int]]] = None + sampling_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None + return_logprob: Optional[Union[List[bool], bool]] = None + logprob_start_len: Optional[Union[List[int], int]] = None + top_logprobs_num: Optional[Union[List[int], int]] = None + stream: bool = False + + # Multimodal placeholders. 
+ image_data: Optional[Any] = None + video_data: Optional[Any] = None + audio_data: Optional[Any] = None + + # Runtime extension placeholders. + lora_path: Optional[Union[List[Optional[str]], str]] = None + session_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None + extra_options: Dict[str, Any] = field(default_factory=dict) + + # Derived fields populated by normalization. + is_single: bool = field(default=True, init=False) + batch_size: int = field(default=1, init=False) + + def normalize_batch_and_arguments(self) -> None: + self._validate_inputs() + self._determine_batch_size() + + def _validate_inputs(self) -> None: + has_text = self.text is not None + has_input_ids = self.input_ids is not None + if has_text == has_input_ids: + raise ValueError("Exactly one of `text` or `input_ids` must be provided.") + + def _determine_batch_size(self) -> None: + if self.text is not None: + if isinstance(self.text, str): + self.is_single = True + self.batch_size = 1 + else: + if len(self.text) == 0: + raise ValueError("`text` cannot be an empty list.") + self.is_single = False + self.batch_size = len(self.text) + return + + assert self.input_ids is not None + if len(self.input_ids) == 0: + raise ValueError("`input_ids` cannot be empty.") + if isinstance(self.input_ids[0], int): + self.is_single = True + self.batch_size = 1 + else: + self.is_single = False + self.batch_size = len(self.input_ids) + + def __getitem__(self, i: int) -> "GenerateReqInput": + if i < 0 or i >= self.batch_size: + raise IndexError(f"index {i} out of range for batch size {self.batch_size}") + if self.batch_size == 1: + return self + return GenerateReqInput( + rid=self._pick(self.rid, i), + text=self._pick(self.text, i), + input_ids=self._pick(self.input_ids, i), + sampling_params=self._pick(self.sampling_params, i), + return_logprob=self._pick(self.return_logprob, i), + logprob_start_len=self._pick(self.logprob_start_len, i), + top_logprobs_num=self._pick(self.top_logprobs_num, i), + 
stream=self.stream, + image_data=self._pick(self.image_data, i), + video_data=self._pick(self.video_data, i), + audio_data=self._pick(self.audio_data, i), + lora_path=self._pick(self.lora_path, i), + session_params=self._pick(self.session_params, i), + extra_options=self.extra_options.copy(), + ) + + @staticmethod + def _pick(value: Any, i: int) -> Any: + if isinstance(value, list): + return value[i] + return value + + def to_request_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + for key, value in { + "rid": self.rid, + "text": self.text, + "input_ids": self.input_ids, + "sampling_params": self.sampling_params, + "return_logprob": self.return_logprob, + "logprob_start_len": self.logprob_start_len, + "top_logprobs_num": self.top_logprobs_num, + "stream": self.stream, + "image_data": self.image_data, + "video_data": self.video_data, + "audio_data": self.audio_data, + "lora_path": self.lora_path, + "session_params": self.session_params, + }.items(): + if value is not None: + payload[key] = value + payload.update(self.extra_options) + return payload + + +@dataclass +class TokenizedGenerateReqInput(BaseReq): + # The decoded text passed to the tokenizer (empty string if only input_ids + # were provided by the caller). + input_text: str = "" + # Token IDs produced by the tokenizer. + input_ids: List[int] = field(default_factory=list) + # Multimodal inputs (processor output, e.g. pixel_values, or raw image / + # audio / video data when no processor is available). ``None`` means the + # request is text-only. + mm_inputs: Optional[Dict[str, Any]] = None + # Raw sampling parameters dict (parsed into a SamplingParams object by the + # model runner when needed). 
+ sampling_params: Dict[str, Any] = field(default_factory=dict) + stream: bool = False + return_logprob: bool = False + logprob_start_len: int = -1 + top_logprobs_num: int = 0 + lora_path: Optional[str] = None + session_params: Optional[Dict[str, Any]] = None + + +@dataclass +class BatchTokenizedGenerateReqInput(BaseBatchReq): + reqs: List[TokenizedGenerateReqInput] + + def __len__(self) -> int: + return len(self.reqs) + + def __getitem__(self, i: int) -> TokenizedGenerateReqInput: + return self.reqs[i] + + def __iter__(self) -> Iterator[TokenizedGenerateReqInput]: + return iter(self.reqs) + + +@dataclass +class BatchTokenIDOutput(BaseBatchReq): + finished_reasons: List[Optional[str]] + decode_ids: List[int] + read_offsets: List[int] + output_ids: Optional[List[int]] + skip_special_tokens: List[bool] + prompt_tokens: List[int] + completion_tokens: List[int] + input_token_logprobs_val: List[float] = field(default_factory=list) + input_token_logprobs_idx: List[int] = field(default_factory=list) + output_token_logprobs_val: List[float] = field(default_factory=list) + output_token_logprobs_idx: List[int] = field(default_factory=list) + input_top_logprobs_val: List[List[float]] = field(default_factory=list) + input_top_logprobs_idx: List[List[int]] = field(default_factory=list) + output_top_logprobs_val: List[List[float]] = field(default_factory=list) + output_top_logprobs_idx: List[List[int]] = field(default_factory=list) + + +@dataclass +class BatchStrOutput(BaseBatchReq): + finished_reasons: List[Optional[str]] + output_strs: List[str] + output_ids: Optional[List[int]] + prompt_tokens: List[int] + completion_tokens: List[int] + input_token_logprobs_val: List[float] = field(default_factory=list) + input_token_logprobs_idx: List[int] = field(default_factory=list) + output_token_logprobs_val: List[float] = field(default_factory=list) + output_token_logprobs_idx: List[int] = field(default_factory=list) + input_top_logprobs_val: List[List[float]] = 
field(default_factory=list) + input_top_logprobs_idx: List[List[int]] = field(default_factory=list) + output_top_logprobs_val: List[List[float]] = field(default_factory=list) + output_top_logprobs_idx: List[List[int]] = field(default_factory=list) diff --git a/pymllm/engine/launch.py b/pymllm/engine/launch.py new file mode 100644 index 000000000..e5214511f --- /dev/null +++ b/pymllm/engine/launch.py @@ -0,0 +1,620 @@ +import asyncio +import atexit +import logging +import os +import uuid +from pathlib import Path +from typing import Any, AsyncIterator, Dict, List, Optional, Union + +import torch +import torch.multiprocessing as mp +from transformers import AutoConfig +from huggingface_hub import snapshot_download + +try: + from pyfiglet import figlet_format + from termcolor import colored + + HAS_BANNER_LIBS = True +except ImportError: + HAS_BANNER_LIBS = False + +from pymllm.configs import get_global_config +from pymllm.engine.io_struct import GenerateReqInput +from pymllm.orchestrator.ipc_utils import make_ipc_address +from pymllm.orchestrator.request_response_process import ( + ReqState, + RequestResponseProcess, +) +from pymllm.orchestrator.tokenizer_process import run_tokenizer_process +from pymllm.orchestrator.scheduler_process import run_scheduler_process +from pymllm.orchestrator.detokenizer_process import run_detokenizer_process + +logger = logging.getLogger(__name__) + +# Standard HuggingFace config fields that indicate max output tokens, +# checked in priority order. 
+_MAX_NEW_TOKENS_FIELDS = ( + "max_new_tokens", + "max_tokens", + "max_completion_tokens", +) + + +def _normalize_eos_raw(raw) -> List[int]: + """Normalize a raw eos_token_id value (int, list, or None) to a list.""" + if raw is None: + return [] + if isinstance(raw, int): + return [raw] + if isinstance(raw, (list, tuple)): + return [x for x in raw if isinstance(x, int)] + return [] + + +def _get_eos_token_ids(hf_config, model_path=None) -> List[int]: + """Extract EOS token ID(s) from a HuggingFace model config. + + Searches in priority order: + 1. ``hf_config.eos_token_id`` (top-level, standard models) + 2. ``hf_config.text_config.eos_token_id`` (VL / multimodal models) + 3. ``generation_config.json`` (many models store EOS here) + 4. ``tokenizer_config.json`` via AutoTokenizer (last resort) + """ + if hf_config is None: + return [] + + # 1. Top-level config + ids = _normalize_eos_raw(getattr(hf_config, "eos_token_id", None)) + if ids: + return ids + + # 2. Nested text_config (VL / multimodal models like Qwen3-VL) + text_config = getattr(hf_config, "text_config", None) + if text_config is not None: + ids = _normalize_eos_raw(getattr(text_config, "eos_token_id", None)) + if ids: + return ids + + # 3. generation_config.json (lightweight, just reads a JSON file) + if model_path is not None: + try: + from transformers import GenerationConfig + + gen_cfg = GenerationConfig.from_pretrained(str(model_path)) + ids = _normalize_eos_raw(getattr(gen_cfg, "eos_token_id", None)) + if ids: + logger.info("EOS token IDs from generation_config.json: %s", ids) + return ids + except Exception: + pass + + # 4. 
Tokenizer (last resort) + if model_path is not None: + try: + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True) + if tok.eos_token_id is not None: + ids = [tok.eos_token_id] + logger.info("EOS token ID from tokenizer: %s", ids) + return ids + except Exception: + pass + + return [] + + +def _get_model_default_max_new_tokens(hf_config) -> Optional[int]: + """Extract max output token limit from a HuggingFace model config. + + Checks standard fields in priority order. Returns ``None`` when the + config does not specify any recognised output-length field. + """ + if hf_config is None: + return None + for field_name in _MAX_NEW_TOKENS_FIELDS: + value = getattr(hf_config, field_name, None) + if value is not None and isinstance(value, int) and value > 0: + logger.info( + "Using model config %s=%d as default max_new_tokens", + field_name, + value, + ) + return value + return None + + +class Engine: + def __init__(self): + self._subprocesses: List[mp.Process] = [] + self._rr_process: Optional[RequestResponseProcess] = None + self._config_logging() + self._set_default_torch_dtype() + self._check_model_and_tokenizer() + + def launch(self) -> None: + self._launch_processes() + atexit.register(self.shutdown) + + def _launch_processes(self) -> None: + """Spawn all subprocess workers and wire up ZMQ IPC channels.""" + mp.set_start_method("spawn", force=True) + uid = str(os.getpid()) + + # IPC addresses for ZMQ communication between processes + addr_request_response_to_tokenizer: str = make_ipc_address( + "request_response_to_tokenizer", uid + ) + addr_tokenizer_to_scheduler: str = make_ipc_address( + "tokenizer_to_scheduler", uid + ) + addr_scheduler_to_detokenizer: str = make_ipc_address( + "scheduler_to_detokenizer", uid + ) + addr_detokenizer_to_request_response: str = make_ipc_address( + "detokenizer_to_request_response", uid + ) + # Record all subprocesses + procs_and_readers: List[tuple] = [] + + # Config 
dict for the tokenizer subprocess (must be picklable). + cfg = get_global_config() + enable_shared_queue = cfg.server.enable_shared_queue + transport_mode: str = ( + cfg.server.tensor_transport_mode + ) # "default" | "cuda_ipc" | "cuda_ipc_pool" + + # Create shared queue if enabled. + # Note: the MmItemMemoryPool (for "cuda_ipc_pool") is created *inside* + # the tokenizer subprocess after CUDA is initialised. The queue here + # is constructed without a pool; TokenizerProcess._ensure_pool() will + # swap in a pool-aware queue at runtime. + shared_queue = None + if enable_shared_queue: + from pymllm.orchestrator.shared_memory_queue import TensorQueue as _TQ + + # Construct with the configured transport mode. The pool is not + # supplied here; it will be lazily initialised inside the subprocess. + shared_queue = _TQ( + maxsize=1000, + transport_mode=transport_mode, + pool=None, # pool initialised lazily inside TokenizerProcess + ) + logger.info( + "Shared memory queue enabled for fast IPC (transport_mode=%s)", + transport_mode, + ) + + tokenizer_cfg: Dict[str, Any] = { + "tokenizer_path": str(cfg.server.tokenizer_path), + "tokenizer_mode": cfg.server.tokenizer_mode, + "trust_remote_code": cfg.server.trust_remote_code, + "context_length": cfg.server.context_length, + "hf_config": cfg.model.hf_config, + "enable_shared_queue": enable_shared_queue, + "tensor_transport_mode": transport_mode, + "cuda_ipc_pool_size_mb": cfg.server.cuda_ipc_pool_size_mb, + "cuda_ipc_recycle_interval": cfg.server.cuda_ipc_recycle_interval, + "log_level": cfg.server.log_level, + } + + # Tokenizer + tokenizer_reader, tokenizer_writer = mp.Pipe(duplex=False) + tokenizer_proc = mp.Process( + target=run_tokenizer_process, + args=( + addr_request_response_to_tokenizer, + addr_tokenizer_to_scheduler, + tokenizer_writer, + tokenizer_cfg, + shared_queue, # Pass shared queue + ), + daemon=True, + ) + procs_and_readers.append((tokenizer_proc, tokenizer_reader, "tokenizer")) + + # Determine default 
max_new_tokens from model config (if available) + model_max_new_tokens = _get_model_default_max_new_tokens( + cfg.model.hf_config + ) + scheduler_kwargs = {} + if model_max_new_tokens is not None: + scheduler_kwargs["default_max_new_tokens"] = model_max_new_tokens + + # Extract EOS token ID(s) from model config + eos_token_ids = _get_eos_token_ids(cfg.model.hf_config, model_path=cfg.server.model_path) + if eos_token_ids: + scheduler_kwargs["eos_token_ids"] = eos_token_ids + logger.info("EOS token IDs for scheduler: %s", eos_token_ids) + + # Model runner config — passed to the scheduler process which now + # owns the model runner in-process (sglang-style architecture). + scheduler_kwargs["server_config"] = cfg.server + scheduler_kwargs["model_config"] = cfg.model + scheduler_kwargs["gpu_id"] = cfg.server.base_gpu_id + + # Scheduler (+ in-process model runner) + scheduler_reader, scheduler_writer = mp.Pipe(duplex=False) + scheduler_proc = mp.Process( + target=run_scheduler_process, + args=( + addr_tokenizer_to_scheduler, + addr_scheduler_to_detokenizer, + scheduler_writer, + shared_queue, # Pass shared queue + enable_shared_queue, # Pass flag + transport_mode, # Pass tensor transport mode + cfg.server.log_level, # Pass log level + ), + kwargs=scheduler_kwargs, + daemon=True, + ) + procs_and_readers.append((scheduler_proc, scheduler_reader, "scheduler")) + + # Detokenizer + detokenizer_reader, detokenizer_writer = mp.Pipe(duplex=False) + detokenizer_proc = mp.Process( + target=run_detokenizer_process, + args=( + addr_scheduler_to_detokenizer, + addr_detokenizer_to_request_response, + detokenizer_writer, + tokenizer_cfg, + ), + daemon=True, + ) + procs_and_readers.append((detokenizer_proc, detokenizer_reader, "detokenizer")) + + # Start all subprocesses + for proc, _, name in procs_and_readers: + proc.start() + self._subprocesses.append(proc) + logger.info("Started %s process (pid=%s)", name, proc.pid) + + # Wait for readiness signals + for _, reader, name in 
procs_and_readers: + try: + msg = reader.recv() + except EOFError: + raise RuntimeError(f"{name} process died before signalling readiness") + if msg.get("status") != "ready": + raise RuntimeError(f"{name} process failed to initialise: {msg}") + logger.info("%s process ready", name) + + # RR Process is current main process — only bind ZMQ sockets here. + # Background tasks are started lazily by listen() on the first + # add_request(), so they always run on the correct event loop. + self._rr_process = RequestResponseProcess( + send_to_tokenizer_addr=addr_request_response_to_tokenizer, + recv_from_detokenizer_addr=addr_detokenizer_to_request_response, + ) + self._rr_process.start() + logger.info("RequestResponseProcess sockets bound") + + # Print colorful gradient ASCII art banner + if HAS_BANNER_LIBS: + try: + text = figlet_format("pymllm", font="slant") + fired_up = figlet_format("FIRED UP!", font="slant") + + # Apply blue-purple gradient + lines = text.strip().split("\n") + colors_cycle = ["blue", "cyan", "blue", "magenta", "magenta"] + for i, line in enumerate(lines): + color = colors_cycle[i % len(colors_cycle)] + print(colored(line, color, attrs=["bold"])) + + # Print "FIRED UP!" in bright magenta + for line in fired_up.strip().split("\n"): + print(colored(line, "magenta", attrs=["bold"])) + print() + except Exception as e: + logger.debug(f"Failed to print banner: {e}") + print("🚀 pymllm FIRED UP! 🚀\n") + else: + print("🚀 pymllm FIRED UP! 
🚀\n") + + def generate( + self, + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + image_data: Optional[Any] = None, + audio_data: Optional[Any] = None, + video_data: Optional[Any] = None, + return_logprob: Optional[Union[List[bool], bool]] = None, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[Union[List[Optional[str]], str]] = None, + session_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, + stream: bool = False, + rid: Optional[Union[List[str], str]] = None, + **kwargs, + ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + """Synchronous, non-streaming generation entry point. + + Accepts a single prompt (``str``) or a batch (``List[str]``). Returns a + single result dict for single inputs and a list of result dicts for batch + inputs, preserving the input order. + """ + rid = self._make_rids(rid, prompt, input_ids) + request = GenerateReqInput( + rid=rid, + text=prompt, + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + stream=stream, + image_data=image_data, + audio_data=audio_data, + video_data=video_data, + lora_path=lora_path, + session_params=session_params, + extra_options=kwargs, + ) + request.normalize_batch_and_arguments() + + async def _run() -> Union[Dict[str, Any], List[Dict[str, Any]]]: + result = await self._rr_process.add_request(request) + if request.is_single: + single_rid = rid if isinstance(rid, str) else rid[0] + return await self._wait_for_final_result(single_rid, result) # type: ignore[arg-type] + # Batch: wait for every sub-request concurrently. 
+ rids_list: List[str] = rid if isinstance(rid, list) else [rid] # type: ignore[assignment] + states: List[ReqState] = result # type: ignore[assignment] + outputs = await asyncio.gather( + *(self._wait_for_final_result(r, s) for r, s in zip(rids_list, states)) + ) + return list(outputs) + + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop.run_until_complete(_run()) + + async def generate_async( + self, + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + image_data: Optional[Any] = None, + audio_data: Optional[Any] = None, + video_data: Optional[Any] = None, + return_logprob: Optional[Union[List[bool], bool]] = None, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[Union[List[Optional[str]], str]] = None, + session_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, + stream: bool = False, + rid: Optional[Union[List[str], str]] = None, + **kwargs, + ) -> AsyncIterator[Dict[str, Any]]: + """Asynchronous generation entry point. + + For a **single** request and ``stream=False`` yields one final result + dict; with ``stream=True`` yields incremental chunks. + + For a **batch** request the iterator yields the final result for each + sub-request as it completes (order not guaranteed); streaming mode yields + incremental chunks from all sub-requests interleaved. 
+ """ + rid = self._make_rids(rid, prompt, input_ids) + request = GenerateReqInput( + rid=rid, + text=prompt, + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + stream=stream, + image_data=image_data, + audio_data=audio_data, + video_data=video_data, + lora_path=lora_path, + session_params=session_params, + extra_options=kwargs, + ) + request.normalize_batch_and_arguments() + result = await self._rr_process.add_request(request) + + if request.is_single: + single_rid = rid if isinstance(rid, str) else rid[0] # type: ignore[index] + state: ReqState = result # type: ignore[assignment] + try: + if stream: + async for chunk in self._stream_results(single_rid, state): + yield chunk + else: + yield await self._wait_for_final_result(single_rid, state) + finally: + if not state.finished: + logger.info("Aborting request %s (client disconnected)", single_rid) + await self._rr_process.abort_request(single_rid) + else: + self._rr_process.remove_state(single_rid) + else: + rids_list: List[str] = rid if isinstance(rid, list) else [rid] # type: ignore[assignment] + states: List[ReqState] = result # type: ignore[assignment] + _bg_tasks: List[asyncio.Task] = [] + try: + if stream: + # Merge streams from all sub-requests using an asyncio queue. 
+ queue: asyncio.Queue = asyncio.Queue() + + async def _forward(r: str, s: ReqState) -> None: + async for chunk in self._stream_results(r, s): + await queue.put(chunk) + await queue.put(None) # sentinel + + _bg_tasks = [ + asyncio.create_task(_forward(r, s)) + for r, s in zip(rids_list, states) + ] + done_count = 0 + while done_count < len(_bg_tasks): + item = await queue.get() + if item is None: + done_count += 1 + else: + yield item + await asyncio.gather(*_bg_tasks) + else: + for coro in asyncio.as_completed( + [ + self._wait_for_final_result(r, s) + for r, s in zip(rids_list, states) + ] + ): + yield await coro + finally: + for t in _bg_tasks: + t.cancel() + for r, s in zip(rids_list, states): + if not s.finished: + logger.info("Aborting request %s (client disconnected)", r) + await self._rr_process.abort_request(r) + else: + self._rr_process.remove_state(r) + + @staticmethod + async def _wait_for_final_result(rid: str, state: ReqState) -> Dict[str, Any]: + """Block until the request is finished and return the last output.""" + while True: + await state.event.wait() + if state.finished: + return state.out_list[-1] + state.event.clear() + + @staticmethod + async def _stream_results( + rid: str, state: ReqState + ) -> AsyncIterator[Dict[str, Any]]: + """Yield incremental chunks as they arrive, until finished.""" + while True: + await state.event.wait() + for item in state.out_list: + yield item + state.out_list.clear() + if state.finished: + return + state.event.clear() + + @staticmethod + def _make_rids( + rid: Optional[Union[str, List[str]]], + prompt: Optional[Union[str, List[str]]], + input_ids: Optional[Union[List[int], List[List[int]]]], + ) -> Union[str, List[str]]: + """Return rids, auto-generating UUIDs when *rid* is ``None``. + + The helper infers whether the call is a batch from *prompt* / *input_ids* + so callers don't have to handle this case themselves. 
+ """ + if rid is not None: + return rid + # Determine batch size from the text/input_ids argument. + is_batch = isinstance(prompt, list) or ( + isinstance(input_ids, list) + and len(input_ids) > 0 + and isinstance(input_ids[0], list) + ) + if is_batch: + n = len(prompt) if prompt is not None else len(input_ids) # type: ignore[arg-type] + return [uuid.uuid4().hex for _ in range(n)] + return uuid.uuid4().hex + + def shutdown(self) -> None: + """Terminate all subprocesses.""" + if self._rr_process is not None: + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(self._rr_process.shutdown()) + else: + loop.run_until_complete(self._rr_process.shutdown()) + except Exception: + pass + for proc in self._subprocesses: + if proc.is_alive(): + proc.terminate() + proc.join(timeout=5) + if proc.is_alive(): + proc.kill() + self._subprocesses.clear() + logger.info("All subprocesses shut down") + + def _set_default_torch_dtype(self): + """Set the default torch dtype based on the server configuration.""" + dtype = get_global_config().server.dtype + if dtype == "auto": + dtype = "bfloat16" if torch.cuda.is_available() else "float32" + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + torch_dtype = dtype_map.get(dtype) + if torch_dtype is None: + raise ValueError(f"Unsupported dtype for torch default dtype: {dtype!r}") + torch.set_default_dtype(torch_dtype) + + def _config_logging(self): + """Configure logging level from server configuration.""" + level_name = get_global_config().server.log_level.upper() + level = getattr(logging, level_name, logging.INFO) + root_logger = logging.getLogger() + if not root_logger.handlers: + logging.basicConfig( + level=level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + else: + root_logger.setLevel(level) + logging.getLogger("pymllm").setLevel(level) + + def _check_model_and_tokenizer(self): + cfg = get_global_config() + if 
cfg.server.model_path is None or cfg.server.tokenizer_path is None: + logger.error("Model path or tokenizer path is not set") + raise ValueError("Model path or tokenizer path is not set") + model_path = cfg.server.model_path + tokenizer_path = cfg.server.tokenizer_path + download_dir = cfg.server.download_dir + trust_remote_code = cfg.server.trust_remote_code + + shared_path = model_path == tokenizer_path + + model_path = self._maybe_download(model_path, download_dir) + cfg.server.model_path = model_path + + if shared_path: + cfg.server.tokenizer_path = model_path + else: + cfg.server.tokenizer_path = self._maybe_download( + tokenizer_path, download_dir + ) + + cfg.model.hf_config = AutoConfig.from_pretrained( + str(model_path), + trust_remote_code=trust_remote_code, + ) + logger.info("Loaded model config: %s", cfg.model.hf_config.__class__.__name__) + + @staticmethod + def _maybe_download(path: Path, download_dir: Optional[Path] = None) -> Path: + if path.is_dir(): + return path + repo_id = str(path) + logger.info("Downloading '%s' ...", repo_id) + kwargs = {} + if download_dir is not None: + kwargs["local_dir"] = str(download_dir / path.name) + downloaded = snapshot_download(repo_id=repo_id, **kwargs) + logger.info("Downloaded '%s' to '%s'", repo_id, downloaded) + return Path(downloaded) diff --git a/pymllm/executor/__init__.py b/pymllm/executor/__init__.py new file mode 100644 index 000000000..b513b8705 --- /dev/null +++ b/pymllm/executor/__init__.py @@ -0,0 +1,10 @@ +"""Executor module: model loading, forward pass, and sampling.""" + +from pymllm.executor.cuda_graph_runner import CudaGraphRunner +from pymllm.executor.model_runner import LogitsProcessorOutput, ModelRunner + +__all__ = [ + "CudaGraphRunner", + "LogitsProcessorOutput", + "ModelRunner", +] diff --git a/pymllm/executor/cuda_graph_runner.py b/pymllm/executor/cuda_graph_runner.py new file mode 100644 index 000000000..fe4fb0e92 --- /dev/null +++ b/pymllm/executor/cuda_graph_runner.py @@ -0,0 +1,590 @@ 
+"""CUDA-graph accelerated forward pass for decode steps. + +Captures CUDA graphs for a set of discrete batch sizes so that the decode +forward pass can be replayed without CPU-side kernel-launch overhead. + +Simplified from sglang's ``CudaGraphRunner`` for pymllm's single-GPU +architecture. Handles: + +* Pre-allocated input buffers (avoids per-step allocations) +* CUDA-graph capture for each batch size +* Optional ``torch.compile`` integration +* Graph replay with padding to the nearest captured batch size + +Typical lifecycle:: + + runner = CudaGraphRunner(model_runner) # captures all batch sizes + + # --- inside the inference loop --- + if runner.can_run(forward_batch): + logits_output = runner.replay(forward_batch) + else: + logits_output = model_runner.forward(forward_batch) + +Integration with :class:`~pymllm.executor.model_runner.ModelRunner` +------------------------------------------------------------------- +The ``ModelRunner`` owns the ``CudaGraphRunner`` and delegates decode +batches to it when the batch size is within the captured range. The +``CudaGraphRunner`` calls ``attn_backend.init_forward_metadata_*_cuda_graph`` +directly (bypassing the normal ``init_forward_metadata`` path) so that +FlashInfer's per-batch planning is recorded inside the graph. 
+""" + +from __future__ import annotations + +import bisect +import gc +import logging +import time +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union + +import torch + +from pymllm.engine.forward_batch import ForwardBatch, ForwardMode +from pymllm.executor.model_runner import LogitsProcessorOutput + +if TYPE_CHECKING: + from pymllm.executor.model_runner import ModelRunner + from pymllm.layers.attention.attention_backend import AttentionBackend + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Global CUDA-graph memory pool (shared across all CudaGraphRunner instances) +# --------------------------------------------------------------------------- + +_global_graph_memory_pool: Optional[tuple] = None + + +def get_global_graph_memory_pool() -> Optional[tuple]: + """Return the shared CUDA graph memory pool handle.""" + return _global_graph_memory_pool + + +def set_global_graph_memory_pool(pool: tuple) -> None: + """Set the shared CUDA graph memory pool handle.""" + global _global_graph_memory_pool + _global_graph_memory_pool = pool + + +# --------------------------------------------------------------------------- +# Context managers +# --------------------------------------------------------------------------- + +# Flag indicating whether we are currently capturing a CUDA graph. +_is_capture_mode: bool = False + + +def is_capture_mode() -> bool: + """Return ``True`` if a CUDA-graph capture is in progress.""" + return _is_capture_mode + + +@contextmanager +def model_capture_mode(): + """Context manager that sets the global capture-mode flag.""" + global _is_capture_mode + _is_capture_mode = True + try: + yield + finally: + _is_capture_mode = False + + +@contextmanager +def freeze_gc(): + """Freeze the garbage collector during CUDA-graph capture. 
+ + GC activity during capture can interfere with the recorded stream + ordering. This context manager collects garbage before capture, + freezes all surviving objects, and unfreezes + re-collects afterwards. + """ + gc.collect() + gc.freeze() + try: + yield + finally: + gc.unfreeze() + gc.collect() + + +# --------------------------------------------------------------------------- +# Pre-allocated input buffers +# --------------------------------------------------------------------------- + + +@dataclass +class _InputBuffers: + """Pre-allocated GPU tensors used as CUDA-graph inputs. + + During graph capture these buffers are used as-is. During replay the + real batch data is copied into the first ``batch_size`` rows while the + remaining padding rows retain their fill values. + """ + + input_ids: torch.Tensor # [max_bs] int64 + req_pool_indices: torch.Tensor # [max_bs] int32 + seq_lens: torch.Tensor # [max_bs] int32 + seq_lens_cpu: torch.Tensor # [max_bs] int32 (CPU) + out_cache_loc: torch.Tensor # [max_bs] int64 + positions: torch.Tensor # [max_bs] int64 + mrope_position_deltas: torch.Tensor # [max_bs] int64 + + @classmethod + def create( + cls, + *, + device: torch.device, + max_bs: int, + seq_len_fill_value: int, + ) -> "_InputBuffers": + """Allocate all buffers for the given maximum batch size.""" + with torch.device(device): + input_ids = torch.zeros((max_bs,), dtype=torch.int64) + req_pool_indices = torch.zeros((max_bs,), dtype=torch.int32) + seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32) + out_cache_loc = torch.zeros((max_bs,), dtype=torch.int64) + positions = torch.zeros((max_bs,), dtype=torch.int64) + mrope_position_deltas = torch.zeros((max_bs,), dtype=torch.int64) + + # seq_lens_cpu must be a real CPU tensor. 
+ seq_lens_cpu = torch.full( + (max_bs,), + seq_len_fill_value, + dtype=torch.int32, + device="cpu", + ) + + return cls( + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + out_cache_loc=out_cache_loc, + positions=positions, + mrope_position_deltas=mrope_position_deltas, + ) + + def populate( + self, + forward_batch: ForwardBatch, + padded_bs: int, + seq_len_fill_value: int, + ) -> None: + """Copy real batch data into the pre-allocated buffers. + + Any padding slots (``[real_bs : padded_bs]``) are filled with safe + defaults so that the captured graph does not access invalid memory. + """ + real_bs = forward_batch.batch_size + + # Reset padding slots when the padded size exceeds the real size. + if padded_bs != real_bs: + self.seq_lens.fill_(seq_len_fill_value) + self.out_cache_loc.zero_() + self.mrope_position_deltas.zero_() + + self.input_ids[:real_bs].copy_(forward_batch.input_ids) + self.req_pool_indices[:real_bs].copy_(forward_batch.req_pool_indices) + self.seq_lens[:real_bs].copy_(forward_batch.seq_lens) + self.out_cache_loc[:real_bs].copy_(forward_batch.out_cache_loc) + self.positions[:real_bs].copy_(forward_batch.positions) + + # Copy M-RoPE position deltas (used by Qwen3-VL for multimodal). + if forward_batch.mrope_position_deltas is not None: + self.mrope_position_deltas[:real_bs].copy_( + forward_batch.mrope_position_deltas + ) + else: + self.mrope_position_deltas[:real_bs].zero_() + + if forward_batch.seq_lens_cpu is not None: + if padded_bs != real_bs: + self.seq_lens_cpu.fill_(seq_len_fill_value) + self.seq_lens_cpu[:real_bs].copy_(forward_batch.seq_lens_cpu) + + +# --------------------------------------------------------------------------- +# Batch-size schedule +# --------------------------------------------------------------------------- + + +def _default_capture_batch_sizes(max_bs: int) -> List[int]: + """Return a list of batch sizes to capture. 
+ + Uses the same schedule as sglang (non-speculative):: + + [1, 2, 4, 8, 12, 16, 24, 32, 40, …, 256, 272, 288, …, 512, 544, …] + + Capped at *max_bs*. + """ + bs_list = ( + [1, 2, 4, 8, 12] + + list(range(16, 257, 8)) + + list(range(272, 512, 16)) + + list(range(512, max_bs + 1, 32)) + ) + bs_list = sorted(set(bs for bs in bs_list if bs <= max_bs)) + if not bs_list: + bs_list = [1] + return bs_list + + +# --------------------------------------------------------------------------- +# CudaGraphRunner +# --------------------------------------------------------------------------- + + +class CudaGraphRunner: + """Captures and replays CUDA graphs for decode-step forward passes. + + This class is the pymllm equivalent of sglang's ``CudaGraphRunner``, + stripped of distributed, speculative-decoding, LoRA, mamba, TBO, and + piecewise-graph complexities. + + Parameters + ---------- + model_runner + The owning :class:`~pymllm.executor.model_runner.ModelRunner`. + Must have been fully initialised before the ``CudaGraphRunner`` + is constructed. 
+ """ + + def __init__(self, model_runner: "ModelRunner"): + self.model_runner = model_runner + self.device = model_runner.device + + self.graphs: Dict[int, torch.cuda.CUDAGraph] = {} + self.output_buffers: Dict[int, LogitsProcessorOutput] = {} + + self.enable_torch_compile: bool = ( + model_runner.server_config.enable_torch_compile + ) + self.torch_compile_max_bs: int = model_runner.server_config.torch_compile_max_bs + + # ----------------------------------------------------------- + # Batch-size schedule + # ----------------------------------------------------------- + max_bs = model_runner.max_running_requests + self.capture_bs: List[int] = _default_capture_batch_sizes(max_bs) + self.compile_bs: List[int] = ( + [bs for bs in self.capture_bs if bs <= self.torch_compile_max_bs] + if self.enable_torch_compile + else [] + ) + self.max_bs: int = max(self.capture_bs) + + logger.info("CUDA graph capture batch sizes: %s", self.capture_bs) + + # ----------------------------------------------------------- + # Attention-backend CUDA-graph state + # ----------------------------------------------------------- + self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs, self.max_bs) + + # Fill value for padded seq_lens so attention kernels don't div-by-0. 
+ self.seq_len_fill_value: int = ( + self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() + ) + + # ----------------------------------------------------------- + # Pre-allocated input buffers + # ----------------------------------------------------------- + self.buffers: _InputBuffers = _InputBuffers.create( + device=torch.device(self.device), + max_bs=self.max_bs, + seq_len_fill_value=self.seq_len_fill_value, + ) + + # ----------------------------------------------------------- + # Optional torch.compile config + # ----------------------------------------------------------- + if self.enable_torch_compile: + _set_torch_compile_config() + + # ----------------------------------------------------------- + # Capture all batch sizes + # ----------------------------------------------------------- + try: + with model_capture_mode(): + self.capture() + except RuntimeError as exc: + raise RuntimeError( + f"CUDA graph capture failed: {exc}\n" + "Possible fixes:\n" + " 1. Reduce --server.mem_fraction_static (e.g. 0.7)\n" + " 2. Reduce --server.max_running_requests\n" + " 3. Disable CUDA graph with --server.disable_cuda_graph\n" + ) from exc + + # ------------------------------------------------------------------ + # Capability check + # ------------------------------------------------------------------ + + def can_run(self, forward_batch: ForwardBatch) -> bool: + """Return ``True`` if the batch can be run via CUDA graph replay. + + The batch must be a decode (or idle) batch whose size does not + exceed the largest captured batch size. + """ + return ( + forward_batch.forward_mode.is_decode_or_idle() + and forward_batch.batch_size <= self.max_bs + ) + + # ------------------------------------------------------------------ + # Capture + # ------------------------------------------------------------------ + + def capture(self) -> None: + """Capture CUDA graphs for every batch size in ``capture_bs``. 
+ + Iterates in reverse order (largest first) so that the GPU memory + pool allocated for the largest graph is reused by smaller ones. + """ + tic = time.perf_counter() + before_mem = _get_avail_mem(self.device) + logger.info("CUDA graph capture begin. avail mem=%.2f GB", before_mem) + + with freeze_gc(): + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + for bs in reversed(self.capture_bs): + forward_fn = self._get_forward_fn(bs) + graph, output = self._capture_one_batch_size(bs, forward_fn, stream) + self.graphs[bs] = graph + self.output_buffers[bs] = output + + after_mem = _get_avail_mem(self.device) + logger.info( + "CUDA graph capture end. elapsed=%.2f s, mem usage=%.2f GB, " + "avail mem=%.2f GB", + time.perf_counter() - tic, + before_mem - after_mem, + after_mem, + ) + + def _get_forward_fn(self, bs: int) -> Callable: + """Return the forward callable for the given batch size. + + When ``torch.compile`` is enabled and *bs* is within the compile + threshold, the model's forward method is wrapped with + ``torch.compile``. + """ + model_forward = self.model_runner.model.forward + if self.enable_torch_compile and bs in self.compile_bs: + return torch.compile( + torch.no_grad()(model_forward), + mode="max-autotune-no-cudagraphs", + ) + return model_forward + + def _capture_one_batch_size( + self, + bs: int, + forward: Callable, + stream: torch.cuda.Stream, + ) -> tuple: + """Capture a single CUDA graph for batch size *bs*. + + Steps: + 1. Build a ``ForwardBatch`` from the pre-allocated buffers. + 2. Tell the attention backend to plan for CUDA-graph capture. + 3. Run the forward pass twice for warmup. + 4. Capture the third run into a ``CUDAGraph``. + + Returns ``(graph, output_buffers)``. + """ + buffers = self.buffers + + # Slice pre-allocated buffers to the capture size. 
+ input_ids = buffers.input_ids[:bs] + req_pool_indices = buffers.req_pool_indices[:bs] + seq_lens = buffers.seq_lens[:bs] + seq_lens_cpu = buffers.seq_lens_cpu[:bs] + out_cache_loc = buffers.out_cache_loc[:bs] + positions = buffers.positions[:bs] + mrope_position_deltas = buffers.mrope_position_deltas[:bs] + + # Build ForwardBatch (DECODE mode). + # mrope_position_deltas is set to the static buffer (initially zeros) + # so that the graph captures the ``positions + deltas`` path. During + # replay the buffer is updated with real delta values. + forward_batch = ForwardBatch( + forward_mode=ForwardMode.DECODE, + batch_size=bs, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc, + seq_lens_sum=int(seq_lens.sum().item()), + seq_lens_cpu=seq_lens_cpu, + positions=positions, + return_logprob=False, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + attn_backend=self.model_runner.attn_backend, + mrope_position_deltas=mrope_position_deltas, + ) + + # Tell the attention backend to set up CUDA-graph-aware metadata. + self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( + bs=bs, + num_tokens=bs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + forward_mode=ForwardMode.DECODE, + ) + + # The single forward-pass function to be captured. + def run_once(): + return forward( + input_ids, + forward_batch.positions, + forward_batch, + ) + + # Warmup (2 eager runs to stabilise cudnn / autotuner / etc.). 
+        for _ in range(2):
+            torch.cuda.synchronize()
+            run_once()
+
+        # ----- Capture -----
+        global _global_graph_memory_pool
+        if _global_graph_memory_pool is None:
+            _global_graph_memory_pool = torch.cuda.graph_pool_handle()
+
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(
+            graph,
+            pool=_global_graph_memory_pool,
+            stream=stream,
+        ):
+            output = run_once()
+
+        return graph, output
+
+    # ------------------------------------------------------------------
+    # Replay
+    # ------------------------------------------------------------------
+
+    def replay(
+        self,
+        forward_batch: ForwardBatch,
+    ) -> LogitsProcessorOutput:
+        """Replay a captured CUDA graph for the given decode batch.
+
+        The batch is padded to the nearest captured size, inputs are copied
+        into the pre-allocated buffers, the graph is replayed, and the
+        output is sliced back to the real batch size.
+
+        Parameters
+        ----------
+        forward_batch
+            The decode batch from the scheduler.
+
+        Returns
+        -------
+        LogitsProcessorOutput
+            The logits for the real (un-padded) sequences.
+
+        Raises
+        ------
+        ValueError
+            If the batch size exceeds the largest captured graph
+            (callers should check :meth:`can_run` first).
+        """
+        real_bs = forward_batch.batch_size
+
+        # Find the smallest captured bs >= real_bs. Guard explicitly so a
+        # contract violation surfaces as a clear error instead of an opaque
+        # IndexError from the list lookup below.
+        idx = bisect.bisect_left(self.capture_bs, real_bs)
+        if idx == len(self.capture_bs):
+            raise ValueError(
+                f"Batch size {real_bs} exceeds the largest captured "
+                f"CUDA graph batch size ({self.capture_bs[-1]}); "
+                "check can_run() before calling replay()."
+            )
+        padded_bs = self.capture_bs[idx]
+
+        # Copy real data into the static buffers.
+        self.buffers.populate(
+            forward_batch,
+            padded_bs=padded_bs,
+            seq_len_fill_value=self.seq_len_fill_value,
+        )
+
+        # Update the attention backend for replay. Padding rows contribute
+        # ``seq_len_fill_value`` tokens each, so account for them in the sum.
+        seq_lens_sum = (
+            forward_batch.seq_lens_sum + (padded_bs - real_bs) * self.seq_len_fill_value
+        )
+        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
+            bs=padded_bs,
+            req_pool_indices=self.buffers.req_pool_indices[:padded_bs],
+            seq_lens=self.buffers.seq_lens[:padded_bs],
+            seq_lens_sum=seq_lens_sum,
+            forward_mode=ForwardMode.DECODE,
+            seq_lens_cpu=self.buffers.seq_lens_cpu[:padded_bs],
+        )
+
+        # Replay the graph.
+        self.graphs[padded_bs].replay()
+
+        # Retrieve output and slice to real batch size.
+        output = self.output_buffers[padded_bs]
+
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[:real_bs],
+                hidden_states=(
+                    output.hidden_states[:real_bs]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        elif isinstance(output, torch.Tensor):
+            # Raw tensor output: assume [padded_bs, vocab_size].
+            return LogitsProcessorOutput(
+                next_token_logits=output[:real_bs],
+            )
+        else:
+            # HuggingFace-style output with .logits attribute.
+            if hasattr(output, "logits"):
+                logits = output.logits
+                if logits.dim() == 3:
+                    return LogitsProcessorOutput(
+                        next_token_logits=logits[:real_bs, -1, :],
+                    )
+                return LogitsProcessorOutput(
+                    next_token_logits=logits[:real_bs],
+                )
+            raise TypeError(f"Unexpected CUDA graph output type: {type(output)}")
+
+    # ------------------------------------------------------------------
+    # Cleanup
+    # ------------------------------------------------------------------
+
+    def shutdown(self) -> None:
+        """Release all captured CUDA graphs and associated buffers.
+
+        Clearing the dictionaries drops the last references to the
+        ``CUDAGraph`` objects, letting CUDA reclaim their memory. (A
+        per-item ``del graph`` loop would only unbind the loop variable
+        and have no effect on the stored graphs.)
+        """
+        self.graphs.clear()
+        self.output_buffers.clear()
+        logger.info("CudaGraphRunner shutdown complete.")
+
+
+# ---------------------------------------------------------------------------
+# Utility helpers
+# ---------------------------------------------------------------------------
+
+
+def _get_avail_mem(device: str) -> float:
+    """Return available GPU memory in GB."""
+    if device != "cuda" or not torch.cuda.is_available():
+        return 0.0
+    free, _ = torch.cuda.mem_get_info()
+    return free / (1 << 30)
+
+
+def _set_torch_compile_config() -> None:
+    """Set dynamo / inductor configs for optimal CUDA-graph + compile."""
+    try:
+        import torch._dynamo.config
+        import torch._inductor.config
+
+        torch._inductor.config.coordinate_descent_tuning = True
+        torch._inductor.config.triton.unique_kernel_names = True
+        torch._inductor.config.fx_graph_cache = True
+        torch._dynamo.config.accumulated_cache_size_limit = 1024
+        if hasattr(torch._dynamo.config, "cache_size_limit"):
+            torch._dynamo.config.cache_size_limit = 1024
+    except ImportError:
+        logger.warning("torch._dynamo / torch._inductor not available.")
diff --git a/pymllm/executor/model_runner.py b/pymllm/executor/model_runner.py
new file mode 100644
index 000000000..6d6f33fea
--- /dev/null
+++ b/pymllm/executor/model_runner.py
@@ -0,0 +1,1198 @@
+"""ModelRunner runs the forward passes of the models.
+
+Simplified from sglang's ``ModelRunner`` for pymllm's single-GPU inference
+architecture. Handles:
+
+* Model loading (HuggingFace checkpoint via ``transformers``)
+* KV-cache memory pool initialisation
+* Attention backend setup (FlashInfer)
+* Forward pass dispatch (extend / decode / idle)
+* Token sampling from logits
+
+Typical lifecycle::
+
+    runner = ModelRunner(server_config, model_config)
+    runner.initialize()
+
+    # --- inside the inference loop ---
+    forward_batch = runner.prepare_forward_batch_decode(...)
+ logits_output = runner.forward(forward_batch) + next_token_ids = runner.sample(logits_output, forward_batch) + +Typical data flow +----------------- + SchedulerProcess builds a batch dict + ↓ + ModelRunnerProcess calls ModelRunner.forward(forward_batch) + ↓ + attn_backend.init_forward_metadata(forward_batch) + ↓ + model.forward(input_ids, positions, forward_batch) + ↓ + ModelRunner.sample(logits_output, forward_batch) + ↓ + next_token_ids returned to scheduler +""" + +from __future__ import annotations + +import gc +import logging +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union + +import torch +from torch import nn + +from pymllm.configs import get_global_config +from pymllm.engine.forward_batch import ForwardBatch, ForwardMode +from pymllm.mem_cache.memory_pool import ( + GDNPool, + KVPool, + ReqToTokenPool, + TokenToKVPoolAllocator, + make_full_attention_net_mem_pool, + make_req_to_token_pool, +) + +if TYPE_CHECKING: + from pymllm.configs.model_config import ModelConfig + from pymllm.configs.server_config import ServerConfig + from pymllm.executor.cuda_graph_runner import CudaGraphRunner + from pymllm.layers.attention.attention_backend import AttentionBackend + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Utility: GPU memory query +# --------------------------------------------------------------------------- + + +def get_available_gpu_memory(device: str = "cuda", gpu_id: int = 0) -> float: + """Return available GPU memory in GB.""" + if device != "cuda" or not torch.cuda.is_available(): + return 0.0 + torch.cuda.set_device(gpu_id) + free, _ = torch.cuda.mem_get_info(gpu_id) + return free / (1 << 30) + + +def get_total_gpu_memory(device: str = "cuda", gpu_id: int = 0) -> float: + """Return total GPU memory in GB.""" + if device != "cuda" or not torch.cuda.is_available(): + return 0.0 + torch.cuda.set_device(gpu_id) + 
_, total = torch.cuda.mem_get_info(gpu_id) + return total / (1 << 30) + + +# --------------------------------------------------------------------------- +# LogitsProcessorOutput +# --------------------------------------------------------------------------- + + +@dataclass +class LogitsProcessorOutput: + """Container for output logits produced by the model's forward pass. + + Attributes + ---------- + next_token_logits + Raw logits for the last token of each sequence in the batch, + shape ``[batch_size, vocab_size]``. + hidden_states + Optional hidden states from the model (e.g. for speculative decoding + or auxiliary loss computation). + """ + + next_token_logits: torch.Tensor # [batch_size, vocab_size] + hidden_states: Optional[torch.Tensor] = None + + +# --------------------------------------------------------------------------- +# ModelRunner +# --------------------------------------------------------------------------- + + +class ModelRunner: + """Runs the forward passes of the models. + + This is the core execution component that owns the model, memory pools, + and attention backend. It is used by + :class:`~pymllm.orchestrator.model_runner_process.ModelRunnerProcess` to + execute batches dispatched by the scheduler. + + Parameters + ---------- + server_config + Server runtime configuration. Falls back to the global singleton + when ``None``. + model_config + Model configuration (wraps a HuggingFace ``PretrainedConfig``). + Falls back to the global singleton when ``None``. + gpu_id + GPU device index to use. 
+ """ + + def __init__( + self, + server_config: Optional["ServerConfig"] = None, + model_config: Optional["ModelConfig"] = None, + gpu_id: int = 0, + ): + cfg = get_global_config() + self.server_config = server_config or cfg.server + self.model_config = model_config or cfg.model + + self.gpu_id = gpu_id + self.device: str = "cuda" if torch.cuda.is_available() else "cpu" + self.dtype: torch.dtype = self._resolve_dtype() + + # Set by initialize() + self.model: Optional[nn.Module] = None + self.req_to_token_pool: Optional[ReqToTokenPool] = None + self.token_to_kv_pool: Optional[KVPool] = None + self.token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None + self.gdn_pool: Optional[GDNPool] = None + self.attn_backend: Optional["AttentionBackend"] = None + self.graph_runner: Optional["CudaGraphRunner"] = None + + # Memory configuration + self.max_total_num_tokens: int = 0 + self.max_running_requests: int = 0 + + # Model metadata (populated after loading) + self.num_hidden_layers: int = 0 + self.num_attention_heads: int = 0 + self.num_kv_heads: int = 0 + self.head_dim: int = 0 + self.hidden_size: int = 0 + self.vocab_size: int = 0 + self.context_len: int = 0 + + # KV cache dtype -- same as model dtype by default; may differ for + # quantised KV caches in the future. + self.kv_cache_dtype: torch.dtype = self.dtype + + # Forward pass counter (monotonically increasing). + self.forward_pass_id: int = 0 + + # ------------------------------------------------------------------ + # Initialisation + # ------------------------------------------------------------------ + + def initialize(self) -> None: + """Full initialisation: set device, load model, init memory + backend. + + Call this once before any forward pass. 
+ """ + tic = time.perf_counter() + logger.info("ModelRunner initialisation begin.") + + # Set device + if self.device == "cuda": + torch.cuda.set_device(self.gpu_id) + + # Set default dtype + torch.set_default_dtype(self.dtype) + + # Load the model + self.load_model() + + # Extract model metadata from hf_config + self._extract_model_metadata() + + # Resolve KV-cache dtype + self._configure_kv_cache_dtype() + + # Initialise memory pools + self.init_memory_pool() + + # Initialise attention backend + self.init_attention_backend() + + # Warm up cuBLAS + if self.device == "cuda": + self._init_cublas() + + # Capture CUDA graphs (must be after model + pools + backend) + self.init_cuda_graphs() + + elapsed = time.perf_counter() - tic + logger.info( + "ModelRunner initialisation complete. elapsed=%.2f s, " + "device=%s, dtype=%s, kv_dtype=%s, max_tokens=%d, max_reqs=%d", + elapsed, + self.device, + self.dtype, + self.kv_cache_dtype, + self.max_total_num_tokens, + self.max_running_requests, + ) + + # ------------------------------------------------------------------ + # Dtype resolution + # ------------------------------------------------------------------ + + def _resolve_dtype(self) -> torch.dtype: + """Resolve the model dtype from configuration.""" + dtype_str = self.server_config.dtype + if dtype_str == "auto": + if torch.cuda.is_available(): + if torch.cuda.get_device_capability()[0] >= 8: + return torch.bfloat16 + return torch.float16 + return torch.float32 + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + result = dtype_map.get(dtype_str) + if result is None: + raise ValueError(f"Unsupported dtype: {dtype_str!r}") + return result + + def _configure_kv_cache_dtype(self) -> None: + """Determine the dtype used for KV-cache storage. + + The global ``QuantizationConfig.kv_cache_dtype`` can override the + model dtype (e.g. ``fp8_e4m3`` for quantised KV caches). When set + to ``"auto"`` the model dtype is used as-is. 
+ """ + cfg = get_global_config() + kv_dtype_str = cfg.quantization.kv_cache_dtype + + if kv_dtype_str == "auto": + self.kv_cache_dtype = self.dtype + return + + kv_dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "fp8_e4m3": torch.float8_e4m3fn, + "fp8_e5m2": torch.float8_e5m2, + } + resolved = kv_dtype_map.get(kv_dtype_str) + if resolved is None: + logger.warning( + "Unrecognised kv_cache_dtype %r, falling back to model dtype.", + kv_dtype_str, + ) + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = resolved + + logger.info("KV-cache dtype: %s", self.kv_cache_dtype) + + # ------------------------------------------------------------------ + # Model metadata + # ------------------------------------------------------------------ + + def _extract_model_metadata(self) -> None: + """Extract key model parameters from the HuggingFace config.""" + hf_config = self.model_config.hf_config + if hf_config is None: + raise RuntimeError( + "HuggingFace config not loaded. " + "Make sure model_config.hf_config is set before calling " + "initialize()." + ) + + # Handle text_config for multimodal models + text_config = getattr(hf_config, "text_config", hf_config) + + self.num_hidden_layers = getattr(text_config, "num_hidden_layers", 0) + self.num_attention_heads = getattr(text_config, "num_attention_heads", 0) + self.num_kv_heads = getattr( + text_config, + "num_key_value_heads", + self.num_attention_heads, + ) + self.head_dim = getattr( + text_config, + "head_dim", + getattr(text_config, "hidden_size", 0) // max(self.num_attention_heads, 1), + ) + self.hidden_size = getattr(text_config, "hidden_size", 0) + self.vocab_size = getattr(text_config, "vocab_size", 0) + + # V-head dim may differ from K-head dim (e.g. 
MLA) + self.v_head_dim: int = getattr(text_config, "v_head_dim", self.head_dim) + + # Context length + self.context_len = self.server_config.context_length or getattr( + text_config, "max_position_embeddings", 4096 + ) + + # Hybrid model metadata (GDN layers) + self.num_gdn_layers: int = getattr(self.model, "num_gdn_layers", 0) + self.full_attn_layer_ids: set = getattr(self.model, "full_attn_layer_ids", set()) + + logger.info( + "Model metadata: layers=%d, q_heads=%d, kv_heads=%d, " + "head_dim=%d, v_head_dim=%d, hidden=%d, vocab=%d, ctx_len=%d" + + (", gdn_layers=%d" if self.num_gdn_layers > 0 else ""), + self.num_hidden_layers, + self.num_attention_heads, + self.num_kv_heads, + self.head_dim, + self.v_head_dim, + self.hidden_size, + self.vocab_size, + self.context_len, + *([self.num_gdn_layers] if self.num_gdn_layers > 0 else []), + ) + + # ------------------------------------------------------------------ + # Model loading + # ------------------------------------------------------------------ + + def load_model(self) -> None: + """Load the model from a HuggingFace checkpoint. + + First checks the pymllm model registry for a custom implementation + that uses ``RadixAttention``. If found, instantiates it with the + HuggingFace config and loads weights via ``load_weights()``. + Otherwise falls back to ``AutoModelForCausalLM.from_pretrained``. + """ + tic = time.perf_counter() + model_path = self.server_config.model_path + + if model_path is None: + raise RuntimeError("server_config.model_path is not set.") + + before_mem = get_available_gpu_memory(self.device, self.gpu_id) + logger.info( + "Load model begin. 
path=%s, avail mem=%.2f GB", + model_path, + before_mem, + ) + + # Look up the architecture in the pymllm model registry + from pymllm.models import _MODEL_REGISTRY, get_model_class + + hf_config = self.model_config.hf_config + architectures = [] + if hf_config is not None: + architectures = getattr(hf_config, "architectures", None) or [] + + if not architectures: + supported = ", ".join(sorted(_MODEL_REGISTRY.keys())) + raise RuntimeError( + f"Cannot determine model architecture from config. " + f"Supported architectures: {supported}" + ) + + architecture = architectures[0] + model_cls = get_model_class(architecture) + if model_cls is None: + supported = ", ".join(sorted(_MODEL_REGISTRY.keys())) + raise RuntimeError( + f"Architecture {architecture!r} is not supported by pymllm. " + f"Supported architectures: {supported}" + ) + + logger.info("Using pymllm model class: %s", model_cls.__name__) + device_str = f"cuda:{self.gpu_id}" if self.device == "cuda" else self.device + # Use set_default_dtype so parameters created without explicit dtype + # get the target dtype, while parameters with explicit dtype=torch.float32 + # (e.g. A_log, dt_bias in GDN layers) stay in float32. + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(self.dtype) + try: + with torch.device(device_str): + self.model = model_cls(hf_config) + finally: + torch.set_default_dtype(old_dtype) + self.model.load_weights(self._iter_weights(model_path)) + self.model.eval() + + after_mem = get_available_gpu_memory(self.device, self.gpu_id) + weight_mem = before_mem - after_mem + logger.info( + "Load model end. elapsed=%.2f s, type=%s, " + "weight_mem=%.2f GB, avail mem=%.2f GB", + time.perf_counter() - tic, + type(self.model).__name__, + weight_mem, + after_mem, + ) + + @staticmethod + def _iter_weights(model_path) -> "Generator[Tuple[str, torch.Tensor], None, None]": + """Yield ``(name, tensor)`` pairs from safetensors or ``.bin`` files. 
+ + Prefers safetensors when available; falls back to PyTorch ``.bin`` + files otherwise. + """ + import glob as _glob + from pathlib import Path + + model_path = Path(model_path) + + # Prefer safetensors + st_files = sorted(_glob.glob(str(model_path / "*.safetensors"))) + if st_files: + from safetensors.torch import load_file + + for fpath in st_files: + state_dict = load_file(fpath) + yield from state_dict.items() + del state_dict + return + + # Fallback: PyTorch .bin files + bin_files = sorted(_glob.glob(str(model_path / "*.bin"))) + for fpath in bin_files: + state_dict = torch.load(fpath, map_location="cpu", weights_only=True) + yield from state_dict.items() + del state_dict + + # ------------------------------------------------------------------ + # Memory pool initialisation + # ------------------------------------------------------------------ + + def init_memory_pool(self) -> None: + """Initialise KV-cache memory pools and request-to-token mapping. + + 1. Profiles available GPU memory to determine the maximum number of + KV-cache token slots (``max_total_num_tokens``). + 2. Derives ``max_running_requests`` from config or heuristic. + 3. Creates :class:`~pymllm.mem_cache.memory_pool.ReqToTokenPool`, + :class:`~pymllm.mem_cache.memory_pool.KVPool`, and + :class:`~pymllm.mem_cache.memory_pool.TokenToKVPoolAllocator`. + """ + logger.info("Initialising memory pools...") + + # Determine max number of tokens in KV cache + self.max_total_num_tokens = self._profile_max_num_tokens() + + # Determine max running requests + max_reqs = self.server_config.max_running_requests + if max_reqs is None: + max_reqs = min( + max( + int(self.max_total_num_tokens / self.context_len * 512), + 2048, + ), + 4096, + ) + self.max_running_requests = max_reqs + + if self.max_total_num_tokens <= 0: + raise RuntimeError( + "Not enough memory for KV cache. " + "Try reducing context_length or using a smaller model." 
+ ) + + # Create ReqToTokenPool + self.req_to_token_pool = make_req_to_token_pool( + max_reqs=self.max_running_requests, + max_context_len=self.context_len + 4, # small padding + device=self.device, + ) + + # Create KVPool + TokenToKVPoolAllocator + # Note: layer_num uses num_hidden_layers even for hybrid models + # because the KV pool is indexed by global layer_id. GDN layers' + # KV slots are allocated but unused (they use GDNPool instead). + self.token_to_kv_pool, self.token_to_kv_pool_allocator = ( + make_full_attention_net_mem_pool( + size=self.max_total_num_tokens, + layer_num=self.num_hidden_layers, + k_head_num=self.num_kv_heads, + k_head_dim=self.head_dim, + v_head_num=self.num_kv_heads, + v_head_dim=self.v_head_dim, + device=self.device, + dtype=self.kv_cache_dtype, + ) + ) + + # Create GDNPool if hybrid model with GDN layers + if self.num_gdn_layers > 0: + hf_config = self.model_config.hf_config + text_config = getattr(hf_config, "text_config", hf_config) + gdn_num_k_heads = getattr(text_config, "linear_num_key_heads", 16) + gdn_num_v_heads = getattr(text_config, "linear_num_value_heads", 32) + gdn_head_k_dim = getattr(text_config, "linear_key_head_dim", 128) + gdn_head_v_dim = getattr(text_config, "linear_value_head_dim", 128) + gdn_conv_kernel = getattr(text_config, "linear_conv_kernel_dim", 4) + gdn_conv_dim = gdn_num_k_heads * gdn_head_k_dim * 2 + gdn_num_v_heads * gdn_head_v_dim + + self.gdn_pool = GDNPool( + max_reqs=self.max_running_requests, + num_gdn_layers=self.num_gdn_layers, + num_v_heads=gdn_num_v_heads, + head_k_dim=gdn_head_k_dim, + head_v_dim=gdn_head_v_dim, + conv_dim=gdn_conv_dim, + conv_kernel_size=gdn_conv_kernel, + device=self.device, + dtype=self.dtype, + max_track_slots=self.max_running_requests, + ) + + logger.info( + "Memory pool initialised: max_tokens=%d, max_reqs=%d, kv_pool=%.2f GB" + + (", gdn_pool=%.2f GB" if self.gdn_pool is not None else ""), + self.max_total_num_tokens, + self.max_running_requests, + 
self.token_to_kv_pool._mem_bytes() / (1 << 30), + *([self.gdn_pool.mem_bytes() / (1 << 30)] if self.gdn_pool is not None else []), + ) + + def _profile_max_num_tokens(self) -> int: + """Profile available memory to determine maximum KV-cache tokens. + + If ``server_config.max_total_tokens`` is explicitly set that value + is used directly. Otherwise a memory-fraction-based heuristic + similar to sglang's ``profile_max_num_token`` is applied. + """ + # If user explicitly set max_total_tokens, use that. + if self.server_config.max_total_tokens is not None: + return self.server_config.max_total_tokens + + if self.device != "cuda": + # For CPU, use a conservative default. + return 4096 + + available_gb = get_available_gpu_memory(self.device, self.gpu_id) + + # Determine memory fraction for static allocation (KV cache). + mem_fraction = self.server_config.mem_fraction_static + if mem_fraction is None: + mem_fraction = 0.85 # default: use 85% of remaining memory + + # Calculate per-token KV cache size in bytes. 
+ kv_element_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() + cell_size = ( + self.num_kv_heads + * (self.head_dim + self.v_head_dim) # K + V + * self.num_hidden_layers + * kv_element_size + ) + + if cell_size == 0: + logger.warning( + "cell_size is 0 (model metadata may be incomplete); " + "using default max_total_num_tokens=4096" + ) + return 4096 + + rest_memory_bytes = int(available_gb * mem_fraction * (1 << 30)) + + # Reserve memory for GDN pool if hybrid model + if self.num_gdn_layers > 0: + hf_config = self.model_config.hf_config + text_config = getattr(hf_config, "text_config", hf_config) + gdn_num_k_heads = getattr(text_config, "linear_num_key_heads", 16) + gdn_num_v_heads = getattr(text_config, "linear_num_value_heads", 32) + gdn_head_k_dim = getattr(text_config, "linear_key_head_dim", 128) + gdn_head_v_dim = getattr(text_config, "linear_value_head_dim", 128) + gdn_conv_kernel = getattr(text_config, "linear_conv_kernel_dim", 4) + gdn_conv_dim = gdn_num_k_heads * gdn_head_k_dim * 2 + gdn_num_v_heads * gdn_head_v_dim + + # Estimate GDN pool memory for max_running_requests + # Track slots add max_reqs_est extra slots for prefix cache snapshots + max_reqs_est = min( + max(int(rest_memory_bytes / cell_size / self.context_len * 512), 2048), + 4096, + ) if self.server_config.max_running_requests is None else self.server_config.max_running_requests + pool_size = max_reqs_est + 1 + max_reqs_est # +track_slots + recurrent_bytes = ( + self.num_gdn_layers * pool_size * gdn_num_v_heads + * gdn_head_v_dim * gdn_head_k_dim * 4 # float32 + ) + dtype_size = torch.tensor([], dtype=self.dtype).element_size() + conv_bytes = ( + self.num_gdn_layers * pool_size * gdn_conv_dim + * (gdn_conv_kernel - 1) * dtype_size + ) + gdn_pool_bytes = recurrent_bytes + conv_bytes + rest_memory_bytes -= gdn_pool_bytes + logger.info( + "GDN pool memory reservation: %.2f GB", + gdn_pool_bytes / (1 << 30), + ) + + max_num_tokens = rest_memory_bytes // cell_size + + 
logger.info( + "Memory profiling: avail=%.2f GB, fraction=%.2f, " + "cell_size=%d bytes, max_tokens=%d", + available_gb, + mem_fraction, + cell_size, + max_num_tokens, + ) + + return max(max_num_tokens, 1) # at least 1 + + # ------------------------------------------------------------------ + # Attention backend + # ------------------------------------------------------------------ + + def init_attention_backend(self) -> None: + """Initialise the attention backend. + + Creates a :class:`FlashInferAttnBackend` for standard models, or a + :class:`HybridAttnBackend` (FlashInfer + GDN) for hybrid models. + """ + from pymllm.layers.attention.flashinfer_backend import FlashInferAttnBackend + + logger.info("Initialising attention backend...") + + flash_backend = FlashInferAttnBackend( + num_heads=self.num_attention_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + kv_cache_dtype=self.kv_cache_dtype, + q_dtype=self.dtype, + max_context_len=self.context_len, + req_to_token=self.req_to_token_pool.req_to_token, + device=torch.device(self.device), + max_req_pool_size=self.req_to_token_pool.size, + ) + + if self.gdn_pool is not None: + from pymllm.layers.attention.gdn_backend import GDNAttnBackend + from pymllm.layers.attention.hybrid_backend import HybridAttnBackend + + gdn_backend = GDNAttnBackend( + gdn_pool=self.gdn_pool, + device=torch.device(self.device), + ) + self.attn_backend = HybridAttnBackend( + full_attn_backend=flash_backend, + gdn_backend=gdn_backend, + full_attn_layer_ids=self.full_attn_layer_ids, + ) + else: + self.attn_backend = flash_backend + + logger.info( + "Attention backend: %s", + type(self.attn_backend).__name__, + ) + + # ------------------------------------------------------------------ + # Warmup + # ------------------------------------------------------------------ + + def _init_cublas(self) -> None: + """Run a small matmul to initialise cuBLAS. 
+ + Without this, the first real matmul may incur a significant + initialisation overhead. + """ + dtype = torch.float16 + device = "cuda" + a = torch.ones((16, 16), dtype=dtype, device=device) + b = torch.ones((16, 16), dtype=dtype, device=device) + _ = a @ b + + # ------------------------------------------------------------------ + # CUDA graph capture + # ------------------------------------------------------------------ + + def init_cuda_graphs(self) -> None: + """Capture CUDA graphs for decode-step acceleration. + + Skipped when: + * The device is not CUDA. + * ``server_config.disable_cuda_graph`` is ``True``. + * The model is not a generation model. + """ + self.graph_runner = None + + if self.device != "cuda": + return + if self.server_config.disable_cuda_graph: + logger.info("CUDA graphs disabled by config.") + return + if not self.is_generation: + return + + from pymllm.executor.cuda_graph_runner import CudaGraphRunner + + tic = time.perf_counter() + before_mem = get_available_gpu_memory(self.device, self.gpu_id) + logger.info("Capturing CUDA graphs... avail mem=%.2f GB", before_mem) + + self.graph_runner = CudaGraphRunner(self) + + after_mem = get_available_gpu_memory(self.device, self.gpu_id) + logger.info( + "CUDA graph capture complete. elapsed=%.2f s, " + "mem usage=%.2f GB, avail mem=%.2f GB", + time.perf_counter() - tic, + before_mem - after_mem, + after_mem, + ) + + # ------------------------------------------------------------------ + # ForwardBatch construction + # ------------------------------------------------------------------ + + def prepare_forward_batch_extend( + self, + input_ids: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + extend_seq_lens: torch.Tensor, + extend_prefix_lens: torch.Tensor, + out_cache_loc: torch.Tensor, + return_logprob: bool = False, + top_logprobs_nums: Optional[List[int]] = None, + ) -> ForwardBatch: + """Build a :class:`ForwardBatch` for an extend (prefill) pass. 
+ + Parameters + ---------- + input_ids + Token IDs for all new tokens, shape ``[total_new_tokens]``. + req_pool_indices + Index of each request in ``ReqToTokenPool``, + shape ``[batch_size]``. + seq_lens + Total (prefix + new) length of each sequence, + shape ``[batch_size]``. + extend_seq_lens + Number of new tokens per sequence, shape ``[batch_size]``. + extend_prefix_lens + Cached prefix length per sequence, shape ``[batch_size]``. + out_cache_loc + KV-pool slot indices for each new token, + shape ``[total_new_tokens]``. + return_logprob + Whether to return per-token log-probabilities. + top_logprobs_nums + Number of top log-probs per sequence. + """ + batch_size = req_pool_indices.shape[0] + seq_lens_sum = int(seq_lens.sum().item()) + extend_num_tokens = int(extend_seq_lens.sum().item()) + + # Compute positions for each token + positions = _compute_positions(extend_seq_lens, extend_prefix_lens) + + # Compute extend_start_loc (exclusive cumsum of extend_seq_lens) + extend_start_loc = torch.zeros( + batch_size, dtype=torch.int32, device=self.device + ) + if batch_size > 1: + extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0).to( + torch.int32 + ) + + return ForwardBatch( + forward_mode=ForwardMode.EXTEND, + batch_size=batch_size, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens_sum, + seq_lens_cpu=seq_lens.cpu(), + positions=positions, + extend_num_tokens=extend_num_tokens, + extend_seq_lens=extend_seq_lens, + extend_prefix_lens=extend_prefix_lens, + extend_start_loc=extend_start_loc, + extend_prefix_lens_cpu=extend_prefix_lens.tolist(), + extend_seq_lens_cpu=extend_seq_lens.tolist(), + return_logprob=return_logprob, + top_logprobs_nums=top_logprobs_nums, + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool=self.token_to_kv_pool, + attn_backend=self.attn_backend, + ) + + def prepare_forward_batch_decode( + self, + input_ids: torch.Tensor, + 
req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + out_cache_loc: torch.Tensor, + return_logprob: bool = False, + top_logprobs_nums: Optional[List[int]] = None, + mrope_position_deltas: Optional[torch.Tensor] = None, + ) -> ForwardBatch: + """Build a :class:`ForwardBatch` for a decode step. + + Parameters + ---------- + input_ids + Token IDs (one per sequence), shape ``[batch_size]``. + req_pool_indices + Index of each request in ``ReqToTokenPool``, + shape ``[batch_size]``. + seq_lens + Total sequence length of each request, shape ``[batch_size]``. + out_cache_loc + KV-pool slot for each sequence's new token, + shape ``[batch_size]``. + return_logprob + Whether to return per-token log-probabilities. + top_logprobs_nums + Number of top log-probs per sequence. + mrope_position_deltas + Per-request M-RoPE position deltas, shape ``[batch_size]`` (int64). + Used by multimodal models (e.g. Qwen3-VL) to offset decode-step + positions by the spatial extent of prefill images. + """ + batch_size = req_pool_indices.shape[0] + seq_lens_sum = int(seq_lens.sum().item()) + + # For decode, positions = seq_lens - 1 (the new token position) + positions = (seq_lens - 1).to(torch.int64) + + return ForwardBatch( + forward_mode=ForwardMode.DECODE, + batch_size=batch_size, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens_sum, + seq_lens_cpu=seq_lens.cpu(), + positions=positions, + return_logprob=return_logprob, + top_logprobs_nums=top_logprobs_nums, + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool=self.token_to_kv_pool, + attn_backend=self.attn_backend, + mrope_position_deltas=mrope_position_deltas, + ) + + # ------------------------------------------------------------------ + # Forward pass + # ------------------------------------------------------------------ + + def forward( + self, + forward_batch: ForwardBatch, + ) -> LogitsProcessorOutput: + """Run a forward pass through 
the model.
+
+        Dispatches to the appropriate method based on the batch's
+        :class:`~pymllm.engine.forward_batch.ForwardMode`. For decode
+        batches, automatically uses CUDA-graph replay when a captured
+        graph is available.
+
+        Parameters
+        ----------
+        forward_batch
+            The prepared batch (from ``prepare_forward_batch_*``).
+
+        Returns
+        -------
+        LogitsProcessorOutput
+            Contains ``next_token_logits`` of shape
+            ``[batch_size, vocab_size]``.
+        """
+        self.forward_pass_id += 1
+
+        if forward_batch.forward_mode.is_idle():
+            return self._forward_idle(forward_batch)
+
+        # Try CUDA graph replay for decode batches.
+        if (
+            forward_batch.forward_mode.is_decode()
+            and self.graph_runner is not None
+            and self.graph_runner.can_run(forward_batch)
+        ):
+            return self.graph_runner.replay(forward_batch)
+
+        if forward_batch.forward_mode.is_decode():
+            return self.forward_decode(forward_batch)
+        elif forward_batch.forward_mode.is_extend():
+            return self.forward_extend(forward_batch)
+        else:
+            raise ValueError(f"Unsupported forward mode: {forward_batch.forward_mode}")
+
+    def forward_decode(
+        self,
+        forward_batch: ForwardBatch,
+    ) -> LogitsProcessorOutput:
+        """Run a decode forward pass (one new token per sequence).
+
+        Calls ``attn_backend.init_forward_metadata`` followed by
+        ``model.forward``.
+        """
+        self.attn_backend.init_forward_metadata(forward_batch)
+        model_output = self.model.forward(
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+        )
+        return self._process_logits(model_output, forward_batch)
+
+    def forward_extend(
+        self,
+        forward_batch: ForwardBatch,
+    ) -> LogitsProcessorOutput:
+        """Run an extend (prefill) forward pass.
+
+        Calls ``attn_backend.init_forward_metadata`` followed by
+        ``model.forward``.
+ """ + self.attn_backend.init_forward_metadata(forward_batch) + model_output = self.model.forward( + forward_batch.input_ids, + forward_batch.positions, + forward_batch, + ) + return self._process_logits(model_output, forward_batch) + + def _forward_idle( + self, + forward_batch: ForwardBatch, + ) -> LogitsProcessorOutput: + """Return empty logits for an idle batch (no sequences to process).""" + return LogitsProcessorOutput( + next_token_logits=torch.empty( + (0, self.vocab_size), + dtype=self.dtype, + device=self.device, + ), + ) + + # ------------------------------------------------------------------ + # Logits post-processing + # ------------------------------------------------------------------ + + def _process_logits( + self, + model_output: Any, + forward_batch: ForwardBatch, + ) -> LogitsProcessorOutput: + """Extract last-token logits from model output. + + Handles: + * A :class:`LogitsProcessorOutput` returned by custom model + implementations. + * A ``CausalLMOutput`` (from HuggingFace ``transformers``) with a + ``.logits`` attribute. + * A raw ``torch.Tensor`` of logits. + """ + if isinstance(model_output, LogitsProcessorOutput): + return model_output + + # Standard HuggingFace output + if hasattr(model_output, "logits"): + logits = model_output.logits + elif isinstance(model_output, torch.Tensor): + logits = model_output + else: + raise TypeError( + f"Unexpected model output type: {type(model_output)}. " + "Expected torch.Tensor or an object with .logits attribute." 
+ ) + + # --- Decode: logits is [bs, 1, vocab] or [bs, vocab] --- + if forward_batch.forward_mode.is_decode(): + if logits.dim() == 3: + next_token_logits = logits[:, -1, :] + else: + next_token_logits = logits + else: + # --- Extend: pick the last token of each sequence --- + next_token_logits = self._gather_last_token_logits(logits, forward_batch) + + return LogitsProcessorOutput(next_token_logits=next_token_logits) + + def _gather_last_token_logits( + self, + logits: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + """Gather the logits of the last token in each sequence for extend. + + During extend, the model processes all tokens but we only need the + logits at the last position of each sequence for next-token sampling. + """ + if logits.dim() == 3: + # [batch_size, seq_len, vocab_size] from standard HF model + return logits[:, -1, :] + + # Flat layout [total_tokens, vocab_size] + if ( + forward_batch.extend_start_loc is not None + and forward_batch.extend_seq_lens is not None + ): + last_indices = ( + forward_batch.extend_start_loc + forward_batch.extend_seq_lens - 1 + ).long() + return logits[last_indices] + + # Fallback: last row + return logits[-1:, :] + + # ------------------------------------------------------------------ + # Sampling + # ------------------------------------------------------------------ + + def sample( + self, + logits_output: LogitsProcessorOutput, + forward_batch: ForwardBatch, + temperatures: Optional[torch.Tensor] = None, + top_ps: Optional[torch.Tensor] = None, + top_ks: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Sample next-token IDs from logits. + + Supports per-request temperature, top-p, and top-k. + + Parameters + ---------- + logits_output + The logits from :meth:`forward`. + forward_batch + The current forward batch. + temperatures + Per-request temperature, shape ``[batch_size]``. + top_ps + Per-request top-p, shape ``[batch_size]``. + top_ks + Per-request top-k, shape ``[batch_size]``. 
+ + Returns + ------- + torch.Tensor + Next-token IDs, shape ``[batch_size]``, dtype ``int32``. + """ + from pymllm.layers.sampling import ( + sampling_from_probs, + softmax, + top_k_top_p_sampling_from_probs, + ) + + logits = logits_output.next_token_logits + + if logits.numel() == 0: + return torch.empty(0, dtype=torch.int32, device=self.device) + + # Greedy path: temperature=0 (or all zeros) → argmax, no sampling. + if temperatures is not None: + all_greedy = bool((temperatures < 1e-6).all()) + else: + all_greedy = False + + if all_greedy: + return logits.argmax(dim=-1).to(torch.int32) + + # Stochastic path: apply temperature then sample. + if temperatures is not None: + probs = softmax(logits, temperature=temperatures) + else: + probs = torch.softmax(logits.float(), dim=-1) + + # Apply top-k / top-p sampling if specified + has_top_k = top_ks is not None + has_top_p = top_ps is not None + + if has_top_k or has_top_p: + k = top_ks if has_top_k else logits.shape[-1] + p = top_ps if has_top_p else 1.0 + next_token_ids = top_k_top_p_sampling_from_probs(probs, k, p) + else: + next_token_ids = sampling_from_probs(probs) + + return next_token_ids + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + def shutdown(self) -> None: + """Release model and memory resources.""" + logger.info("ModelRunner shutting down...") + + if self.graph_runner is not None: + self.graph_runner.shutdown() + self.graph_runner = None + if self.model is not None: + del self.model + self.model = None + if self.token_to_kv_pool is not None: + del self.token_to_kv_pool + self.token_to_kv_pool = None + if self.token_to_kv_pool_allocator is not None: + del self.token_to_kv_pool_allocator + self.token_to_kv_pool_allocator = None + if self.gdn_pool is not None: + del self.gdn_pool + self.gdn_pool = None + if self.req_to_token_pool is not None: + del self.req_to_token_pool + self.req_to_token_pool 
= None + self.attn_backend = None + + if self.device == "cuda": + torch.cuda.empty_cache() + gc.collect() + + logger.info("ModelRunner shutdown complete.") + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def is_generation(self) -> bool: + """True if the model is a generation (causal-LM) model.""" + return True + + @property + def sliding_window_size(self) -> Optional[int]: + """Sliding-window attention span, or ``None`` for full context.""" + hf_config = self.model_config.hf_config + if hf_config is None: + return None + text_config = getattr(hf_config, "text_config", hf_config) + return getattr(text_config, "sliding_window", None) + + +# --------------------------------------------------------------------------- +# Utility functions +# --------------------------------------------------------------------------- + + +def _compute_positions( + extend_seq_lens: torch.Tensor, + extend_prefix_lens: torch.Tensor, +) -> torch.Tensor: + """Compute per-token positions for an extend batch. + + For each sequence, positions are + ``[prefix_len, prefix_len+1, ..., prefix_len+seq_len-1]``. + The result is a flat 1-D tensor of shape ``[sum(extend_seq_lens)]``. 
+ """ + device = extend_seq_lens.device + batch_size = extend_seq_lens.shape[0] + total_tokens = int(extend_seq_lens.sum().item()) + + if total_tokens == 0: + return torch.empty(0, dtype=torch.int64, device=device) + + positions = torch.empty(total_tokens, dtype=torch.int64, device=device) + offset = 0 + for i in range(batch_size): + seq_len = int(extend_seq_lens[i].item()) + prefix_len = int(extend_prefix_lens[i].item()) + if seq_len > 0: + positions[offset : offset + seq_len] = torch.arange( + prefix_len, + prefix_len + seq_len, + dtype=torch.int64, + device=device, + ) + offset += seq_len + + return positions diff --git a/pymllm/layers/__init__.py b/pymllm/layers/__init__.py new file mode 100644 index 000000000..2ecb13965 --- /dev/null +++ b/pymllm/layers/__init__.py @@ -0,0 +1,65 @@ +"""Layers module for pymllm.""" + +from pymllm.layers.base import MllmBaseLayer +from pymllm.layers.embedding import VocabParallelEmbedding +from pymllm.layers.layer_norm import LayerNorm +from pymllm.layers.linear import ColumnParallelLinear, Linear, RowParallelLinear +from pymllm.layers.mlp import MLP, ParallelMLP +from pymllm.layers.rms_norm import GemmaRMSNorm, RMSNorm +from pymllm.layers.rms_norm_gated import RMSNormGated +from pymllm.layers.gated_delta_net import GatedDeltaNet +from pymllm.layers.rope import ( + apply_llama31_rope, + apply_llama31_rope_pos_ids, + apply_mrope, + apply_rope, + apply_rope_pos_ids, + apply_rope_with_cos_sin_cache, +) +from pymllm.layers.sampling import ( + chain_speculative_sampling, + min_p_sampling_from_probs, + sampling_from_logits, + sampling_from_probs, + softmax, + top_k_mask_logits, + top_k_renorm_probs, + top_k_sampling_from_probs, + top_k_top_p_sampling_from_logits, + top_k_top_p_sampling_from_probs, + top_p_renorm_probs, + top_p_sampling_from_probs, +) +from pymllm.layers.utils import set_weight_attrs + +__all__ = [ + "MllmBaseLayer", + "set_weight_attrs", + "VocabParallelEmbedding", + "ColumnParallelLinear", + "Linear", + 
"RowParallelLinear", + "MLP", + "ParallelMLP", + "LayerNorm", + "RMSNorm", + "GemmaRMSNorm", + "apply_mrope", + "apply_rope", + "apply_llama31_rope", + "apply_rope_pos_ids", + "apply_llama31_rope_pos_ids", + "apply_rope_with_cos_sin_cache", + "softmax", + "sampling_from_probs", + "sampling_from_logits", + "top_p_sampling_from_probs", + "top_k_sampling_from_probs", + "min_p_sampling_from_probs", + "top_k_top_p_sampling_from_logits", + "top_k_top_p_sampling_from_probs", + "top_p_renorm_probs", + "top_k_renorm_probs", + "top_k_mask_logits", + "chain_speculative_sampling", +] diff --git a/pymllm/layers/attention/__init__.py b/pymllm/layers/attention/__init__.py new file mode 100644 index 000000000..ae187975d --- /dev/null +++ b/pymllm/layers/attention/__init__.py @@ -0,0 +1,33 @@ +"""Attention layers and backends for pymllm.""" + +from pymllm.layers.attention.attention_backend import AttentionBackend +from pymllm.layers.attention.flashinfer_backend import ( + DecodeMetadata, + FlashInferAttnBackend, + PrefillMetadata, + WrapperDispatch, + should_use_tensor_core, +) +from pymllm.layers.attention.gdn_backend import GDNAttnBackend +from pymllm.layers.attention.hybrid_backend import HybridAttnBackend +from pymllm.layers.attention.radix_attention import AttentionType, RadixAttention +from pymllm.layers.attention.radix_linear_attention import RadixLinearAttention + +__all__ = [ + # Base + "AttentionBackend", + # RadixAttention + "AttentionType", + "RadixAttention", + # RadixLinearAttention (GDN) + "RadixLinearAttention", + # FlashInfer backend + "FlashInferAttnBackend", + "DecodeMetadata", + "PrefillMetadata", + "WrapperDispatch", + "should_use_tensor_core", + # GDN + Hybrid backends + "GDNAttnBackend", + "HybridAttnBackend", +] diff --git a/pymllm/layers/attention/attention_backend.py b/pymllm/layers/attention/attention_backend.py new file mode 100644 index 000000000..fe168c2d2 --- /dev/null +++ b/pymllm/layers/attention/attention_backend.py @@ -0,0 +1,165 @@ +"""Abstract 
base class for pymllm attention backends. + +Every concrete backend (FlashInfer, Triton, torch-native, …) must implement +at minimum: + + * ``init_forward_metadata`` – called once per batch before the model forward. + * ``forward_extend`` – prefill / extend attention. + * ``forward_decode`` – single-token decode attention. + +The public ``forward`` method dispatches to the correct variant based on +``forward_batch.forward_mode``. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Optional + +import torch + +if TYPE_CHECKING: + from pymllm.engine.forward_batch import ForwardBatch, ForwardMode + from pymllm.layers.attention.radix_attention import RadixAttention + + +class AttentionBackend(ABC): + """Abstract base class for attention backends. + + All concrete backends inherit from this class and implement the abstract + methods below. + """ + + # ------------------------------------------------------------------ + # Core interface – must be implemented by every backend + # ------------------------------------------------------------------ + + @abstractmethod + def init_forward_metadata(self, forward_batch: "ForwardBatch") -> None: + """Prepare per-batch metadata before the model's attention layers run. + + For FlashInfer this plans the KV-index arrays and calls + ``wrapper.begin_forward``; for Triton / torch-native this is a no-op. + Must be called once per batch *before* ``model.forward``. 
+ """ + raise NotImplementedError + + @abstractmethod + def forward_decode( + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + layer: "RadixAttention", + forward_batch: "ForwardBatch", + save_kv_cache: bool = True, + **kwargs, + ) -> torch.Tensor: + """Run attention for a decode step (one new token per sequence).""" + raise NotImplementedError + + @abstractmethod + def forward_extend( + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + layer: "RadixAttention", + forward_batch: "ForwardBatch", + save_kv_cache: bool = True, + **kwargs, + ) -> torch.Tensor: + """Run attention for a prefill / extend step.""" + raise NotImplementedError + + # ------------------------------------------------------------------ + # Dispatch – shared logic; do not override in normal backends + # ------------------------------------------------------------------ + + def forward( + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + layer: "RadixAttention", + forward_batch: "ForwardBatch", + save_kv_cache: bool = True, + **kwargs, + ) -> torch.Tensor: + """Dispatch to ``forward_decode`` or ``forward_extend`` based on mode. + + For IDLE batches a zero-filled output tensor is returned without any + compute. + """ + if forward_batch.forward_mode.is_idle(): + # Return empty output without computation. 
+ return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim) + elif forward_batch.forward_mode.is_decode(): + return self.forward_decode( + q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs + ) + else: + return self.forward_extend( + q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs + ) + + # ------------------------------------------------------------------ + # GDN linear-attention interface (used by HybridAttnBackend) + # ------------------------------------------------------------------ + + def forward_gdn( + self, + layer: "RadixLinearAttention", + forward_batch: "ForwardBatch", + mixed_qkv: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + ) -> torch.Tensor: + """Run GDN linear-attention for one layer. + + Only implemented by backends that support hybrid (full + GDN) + architectures. The default raises ``NotImplementedError``. + """ + raise NotImplementedError( + f"{type(self).__name__} does not support GDN linear attention. " + "Use HybridAttnBackend for hybrid full+GDN models." + ) + + # ------------------------------------------------------------------ + # Optional CUDA-graph interface + # ------------------------------------------------------------------ + + def get_cuda_graph_seq_len_fill_value(self) -> int: + """Fill value used to pad ``seq_lens`` tensors for CUDA-graph capture. + + Most backends use ``1`` (not ``0``) to avoid division-by-zero in + attention kernels. 
+ """ + raise NotImplementedError + + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None: + """Allocate shared CUDA-graph state (buffers reused across captures).""" + raise NotImplementedError + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + forward_mode: "ForwardMode", + ) -> None: + """Set up per-batch metadata for capturing a CUDA graph.""" + raise NotImplementedError + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + forward_mode: "ForwardMode", + seq_lens_cpu: Optional[torch.Tensor], + ) -> None: + """Update metadata when replaying a captured CUDA graph.""" + raise NotImplementedError diff --git a/pymllm/layers/attention/flashinfer_backend.py b/pymllm/layers/attention/flashinfer_backend.py new file mode 100644 index 000000000..479fb5cec --- /dev/null +++ b/pymllm/layers/attention/flashinfer_backend.py @@ -0,0 +1,964 @@ +"""FlashInfer attention backend for pymllm. + + * No model-runner object -- constructor takes explicit scalar / tensor params. + * No tensor-parallelism head splitting (handled at the model layer level). + * No speculative decoding support. + * ``KVPool`` API: + - ``get_kv_buffer(layer_id)`` returns ``(k_buf, v_buf)`` each shaped + ``[buf_len, num_heads, head_dim]``. + - ``set_kv_buffer(layer_id, indices, k, v)`` -- no scale arguments. + +Supports: + * Single-wrapper mode (full context, no sliding window) + * Sliding-window mode (two wrappers: window + full) + * CUDA-graph capture / replay for decode and target-verify passes. 
+""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from enum import Enum, auto +from typing import List, Optional, Union + +import torch + +from pymllm.engine.forward_batch import ForwardBatch, ForwardMode +from pymllm.layers.attention.attention_backend import AttentionBackend +from mllm_kernel.cuda.jit.create_kv_indices import create_kv_indices + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Optional FlashInfer import +# --------------------------------------------------------------------------- + +_flashinfer_available = False +try: + from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, + ) + + try: + from flashinfer import fast_decode_plan + from functools import partial as _partial + + _has_fast_decode_plan = True + except ImportError: + _has_fast_decode_plan = False + + from flashinfer.cascade import merge_state + + _flashinfer_available = True +except ImportError: + logger.warning( + "flashinfer is not installed; FlashInferAttnBackend will raise " + "NotImplementedError if used." + ) + +# --------------------------------------------------------------------------- +# Global workspace buffer (shared across all FlashInfer wrapper instances) +# --------------------------------------------------------------------------- + +_global_workspace_buffer: Optional[torch.Tensor] = None + +# Default workspace size (128 MB); can be overridden via environment variable. 
+_DEFAULT_WORKSPACE_BYTES = int( + os.environ.get("PYMLLM_FLASHINFER_WORKSPACE_SIZE", 128 * 1024 * 1024) +) + +# --------------------------------------------------------------------------- +# Enums / dataclasses +# --------------------------------------------------------------------------- + + +class WrapperDispatch(Enum): + """Indicates which wrapper to use for a given attention layer.""" + + SLIDING_WINDOW = auto() + CROSS_ATTENTION = auto() + + +@dataclass +class DecodeMetadata: + """Per-batch metadata for a decode step.""" + + decode_wrappers: "List[BatchDecodeWithPagedKVCacheWrapper]" + + +@dataclass +class PrefillMetadata: + """Per-batch metadata for a prefill / extend step.""" + + prefill_wrappers: "List[BatchPrefillWithPagedKVCacheWrapper]" + use_ragged: bool + extend_no_prefix: bool + + +# --------------------------------------------------------------------------- +# CUDA kernel – build the flat kv_indices array for FlashInfer +# --------------------------------------------------------------------------- + +# --------------------------------------------------------------------------- +# Helper – choose whether to use tensor cores for decode +# --------------------------------------------------------------------------- + + +def should_use_tensor_core( + kv_cache_dtype: torch.dtype, + num_attention_heads: int, + num_kv_heads: int, +) -> bool: + """Return whether FlashInfer decode should use tensor cores. + + For FP8 we always use tensor cores. For fp16 / bf16 we use them when + the GQA group size (num_attention_heads / num_kv_heads) is ≥ 4, which + fuses the head group with the token dimension in the MMA instruction. 
+ """ + env_override = os.environ.get("PYMLLM_FLASHINFER_USE_TENSOR_CORE") + if env_override is not None: + return env_override.lower() == "true" + + try: + from flashinfer.decode import _grouped_size_compiled_for_decode_kernels + + return not _grouped_size_compiled_for_decode_kernels( + num_attention_heads, num_kv_heads + ) + except (ImportError, AttributeError): + pass + + gqa_group_size = num_attention_heads // num_kv_heads + if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + return True + if kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16): + return gqa_group_size >= 4 + return False + + +# --------------------------------------------------------------------------- +# FlashInferAttnBackend +# --------------------------------------------------------------------------- + + +class FlashInferAttnBackend(AttentionBackend): + """FlashInfer-based attention backend for pymllm. + + This class does not depend on a ``ModelRunner`` object. Instead it takes + all required configuration explicitly so that it can be constructed + independently of any particular model runner. + + Parameters + ---------- + num_heads + Number of query heads per device (after any TP sharding). + num_kv_heads + Number of KV heads per device. + head_dim + Per-head dimension for Q and K. + kv_cache_dtype + ``torch.dtype`` of the KV cache (e.g. ``torch.float16``). + q_dtype + ``torch.dtype`` of the query tensor. + max_context_len + Maximum sequence length the model supports. + req_to_token + The ``[max_reqs, max_context_len]`` int32 tensor from + ``ReqToTokenPool.req_to_token``. + device + Target device (e.g. ``torch.device("cuda")``) + max_req_pool_size + Maximum number of concurrent requests (= ``ReqToTokenPool.size``). + Used to pre-allocate ``kv_indptr`` / ``kv_last_page_len`` buffers. + sliding_window_size + When not ``None``, enables sliding-window attention mode which + allocates two wrapper sets (window + full context). 
+ skip_prefill + When ``True``, skip creating prefill wrappers (for backends that only + perform decode, e.g. multi-step draft backends). + kv_indptr_buf + Optional pre-allocated ``kv_indptr`` buffer. Used when sharing + buffers across multiple backend instances (e.g. multi-step draft). + kv_last_page_len_buf + Optional pre-allocated ``kv_last_page_len`` buffer. + init_new_workspace + When ``True`` allocate a fresh workspace buffer instead of reusing the + global one. + """ + + def __init__( + self, + num_heads: int, + num_kv_heads: int, + head_dim: int, + kv_cache_dtype: torch.dtype, + q_dtype: torch.dtype, + max_context_len: int, + req_to_token: torch.Tensor, + device: torch.device, + max_req_pool_size: int, + sliding_window_size: Optional[int] = None, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + kv_last_page_len_buf: Optional[torch.Tensor] = None, + init_new_workspace: bool = False, + ): + if not _flashinfer_available: + raise RuntimeError( + "flashinfer is required for FlashInferAttnBackend but is not " + "installed. 
Run: pip install flashinfer-python" + ) + + super().__init__() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.kv_cache_dtype = kv_cache_dtype + self.q_dtype = q_dtype + self.max_context_len = max_context_len + self.req_to_token = req_to_token + self.device = device + self.skip_prefill = skip_prefill + + # Tensor-core preference for decode + self.decode_use_tensor_cores = should_use_tensor_core( + kv_cache_dtype, num_heads, num_kv_heads + ) + + # Sliding-window / cross-attention wrapper dispatch + if sliding_window_size is not None: + self.num_wrappers = 2 + self.dispatch_reason: Optional[WrapperDispatch] = ( + WrapperDispatch.SLIDING_WINDOW + ) + self.sliding_window_size: Optional[int] = sliding_window_size + else: + self.num_wrappers = 1 + self.dispatch_reason = None + self.sliding_window_size = None + + # ------------------------------------------------------------------ + # Workspace buffer + # ------------------------------------------------------------------ + global _global_workspace_buffer + if _global_workspace_buffer is None: + _global_workspace_buffer = torch.empty( + _DEFAULT_WORKSPACE_BYTES, + dtype=torch.uint8, + device=device, + ) + if init_new_workspace: + self.workspace_buffer = torch.empty( + _DEFAULT_WORKSPACE_BYTES, + dtype=torch.uint8, + device=device, + ) + else: + self.workspace_buffer = _global_workspace_buffer + + # ------------------------------------------------------------------ + # kv_indptr [num_wrappers × (max_req_pool_size + 1)] + # kv_last_page_len [max_req_pool_size] + # ------------------------------------------------------------------ + if kv_indptr_buf is None: + self.kv_indptr: List[torch.Tensor] = [ + torch.zeros((max_req_pool_size + 1,), dtype=torch.int32, device=device) + for _ in range(self.num_wrappers) + ] + else: + assert self.num_wrappers == 1 + self.kv_indptr = [kv_indptr_buf] + + if kv_last_page_len_buf is None: + self.kv_last_page_len = torch.ones( + 
(max_req_pool_size,), dtype=torch.int32, device=device + ) + else: + assert self.num_wrappers == 1 + self.kv_last_page_len = kv_last_page_len_buf + + # qo_indptr – only needed for prefill + if not skip_prefill: + self.qo_indptr: List[torch.Tensor] = [ + torch.zeros((max_req_pool_size + 1,), dtype=torch.int32, device=device) + for _ in range(self.num_wrappers) + ] + + # ------------------------------------------------------------------ + # Create FlashInfer wrappers + # ------------------------------------------------------------------ + self.prefill_wrapper_ragged: Optional[ + "BatchPrefillWithRaggedKVCacheWrapper" + ] = None + self.prefill_wrappers_paged: List["BatchPrefillWithPagedKVCacheWrapper"] = [] + self.decode_wrappers: List["BatchDecodeWithPagedKVCacheWrapper"] = [] + + if not skip_prefill: + self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( + self.workspace_buffer, "NHD" + ) + + for _ in range(self.num_wrappers): + if not skip_prefill: + self.prefill_wrappers_paged.append( + BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") + ) + self.decode_wrappers.append( + BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_tensor_cores=self.decode_use_tensor_cores, + ) + ) + + # ------------------------------------------------------------------ + # Indices updaters + # ------------------------------------------------------------------ + if not skip_prefill: + self.indices_updater_prefill = _FlashInferIndicesUpdaterPrefill(self) + self.indices_updater_decode = _FlashInferIndicesUpdaterDecode(self) + + # Per-batch metadata set by init_forward_metadata + self.forward_metadata: Optional[Union[DecodeMetadata, PrefillMetadata]] = None + + # CUDA-graph metadata stores + self.decode_cuda_graph_metadata: dict = {} + self.prefill_cuda_graph_metadata: dict = {} + + # ------------------------------------------------------------------ + # init_forward_metadata + # 
------------------------------------------------------------------ + + def init_forward_metadata(self, forward_batch: ForwardBatch) -> None: + """Prepare FlashInfer wrappers for the current batch. + + Must be called once per batch before the model's ``forward`` method. + """ + if forward_batch.forward_mode.is_decode_or_idle(): + self.indices_updater_decode.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_cpu, + forward_batch.seq_lens_sum, + decode_wrappers=self.decode_wrappers, + ) + self.forward_metadata = DecodeMetadata(self.decode_wrappers) + else: + # Extend / prefill + prefix_lens = forward_batch.extend_prefix_lens + extend_no_prefix = ( + forward_batch.extend_prefix_lens_cpu is not None + and not any(forward_batch.extend_prefix_lens_cpu) + ) + use_ragged = extend_no_prefix + + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_cpu, + forward_batch.seq_lens_sum, + prefix_lens=prefix_lens, + prefill_wrappers=self.prefill_wrappers_paged, + use_ragged=use_ragged, + ) + self.forward_metadata = PrefillMetadata( + self.prefill_wrappers_paged, + use_ragged=use_ragged, + extend_no_prefix=extend_no_prefix, + ) + + # ------------------------------------------------------------------ + # forward_extend + # ------------------------------------------------------------------ + + def forward_extend( + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + layer: "RadixAttention", # noqa: F821 + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + **kwargs, + ) -> torch.Tensor: + from pymllm.layers.attention.radix_attention import RadixAttention + + assert isinstance(layer, RadixAttention) + meta: PrefillMetadata = self.forward_metadata + + prefill_wrapper_paged = meta.prefill_wrappers[self._get_wrapper_idx(layer)] + cache_loc = forward_batch.out_cache_loc + + # Write K/V into the pool + if k is not None: + assert v is not 
None + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer.layer_id, cache_loc, k, v + ) + + q_3d = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + + if not meta.use_ragged: + # Paged-only path: uses the full KV cache (prefix + extend). + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + # Reshape to [buf_len, page_size=1, num_heads, head_dim] for FlashInfer. + paged_kv = (k_cache.unsqueeze(1), v_cache.unsqueeze(1)) + + o = prefill_wrapper_paged.forward( + q_3d, + paged_kv, + causal=not layer.is_cross_attention, + sm_scale=layer.scaling, + window_left=layer.sliding_window_size, + logits_soft_cap=layer.logit_cap if layer.logit_cap > 0 else None, + ) + else: + # Ragged path: query attends only to the new (ragged) K/V; + # prefix K/V is in the paged pool. + if k is None: + # Fallback: load K/V from the pool. + k_buf, v_buf = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + k = k_buf + v = v_buf + + k_3d = k.view(-1, layer.tp_k_head_num, layer.head_dim) + v_3d = v.view(-1, layer.tp_v_head_num, layer.v_head_dim) + + if meta.extend_no_prefix: + # Pure prefill – no prefix at all. + o = self.prefill_wrapper_ragged.forward( + q_3d, + k_3d, + v_3d, + causal=True, + sm_scale=layer.scaling, + logits_soft_cap=(layer.logit_cap if layer.logit_cap > 0 else None), + ) + else: + # Extend with prefix: merge ragged (new) and paged (prefix). 
+ o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( + q_3d, + k_3d, + v_3d, + causal=True, + sm_scale=layer.scaling, + logits_soft_cap=(layer.logit_cap if layer.logit_cap > 0 else None), + ) + + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + paged_kv = (k_cache.unsqueeze(1), v_cache.unsqueeze(1)) + o2, s2 = prefill_wrapper_paged.forward_return_lse( + q_3d, + paged_kv, + causal=False, + sm_scale=layer.scaling, + logits_soft_cap=(layer.logit_cap if layer.logit_cap > 0 else None), + ) + + o, _ = merge_state(o1, s1, o2, s2) + + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + + # ------------------------------------------------------------------ + # forward_decode + # ------------------------------------------------------------------ + + def forward_decode( + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + layer: "RadixAttention", # noqa: F821 + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + **kwargs, + ) -> torch.Tensor: + from pymllm.layers.attention.radix_attention import RadixAttention + + assert isinstance(layer, RadixAttention) + meta: DecodeMetadata = self.forward_metadata + + decode_wrapper = meta.decode_wrappers[self._get_wrapper_idx(layer)] + cache_loc = forward_batch.out_cache_loc + + if k is not None: + assert v is not None + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer.layer_id, cache_loc, k, v + ) + + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + paged_kv = (k_cache.unsqueeze(1), v_cache.unsqueeze(1)) + + o = decode_wrapper.forward( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + paged_kv, + sm_scale=layer.scaling, + logits_soft_cap=layer.logit_cap if layer.logit_cap > 0 else None, + ) + + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + + # ------------------------------------------------------------------ + # CUDA-graph support + # 
------------------------------------------------------------------ + + def get_cuda_graph_seq_len_fill_value(self) -> int: + return 1 + + def init_cuda_graph_state( + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, + ) -> None: + """Allocate CUDA-graph shared state buffers.""" + if kv_indices_buf is None: + cuda_graph_kv_indices = torch.zeros( + (max_num_tokens * self.max_context_len,), + dtype=torch.int32, + device=self.device, + ) + else: + cuda_graph_kv_indices = kv_indices_buf + + self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [ + cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1) + ] + + if not self.skip_prefill: + self.cuda_graph_custom_mask = torch.zeros( + (max_num_tokens * self.max_context_len,), + dtype=torch.uint8, + device=self.device, + ) + self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr] + self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr] + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + forward_mode: ForwardMode, + ) -> None: + """Set up metadata for CUDA-graph capture of a decode step.""" + if not forward_mode.is_decode_or_idle(): + raise ValueError( + "CUDA-graph capture is only supported for decode / idle modes." 
+ ) + + decode_wrappers = [] + for i in range(self.num_wrappers): + decode_wrappers.append( + BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=self.decode_use_tensor_cores, + paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1], + paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], + paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_tokens], + ) + ) + + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_decode.update( + req_pool_indices, + seq_lens, + seq_lens.cpu(), + seq_lens_sum, + decode_wrappers=decode_wrappers, + ) + self.decode_cuda_graph_metadata[bs] = decode_wrappers + self.forward_metadata = DecodeMetadata(decode_wrappers) + + if _has_fast_decode_plan: + for i in range(self.num_wrappers): + decode_wrappers[i].begin_forward = _partial( + fast_decode_plan, decode_wrappers[i] + ) + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + forward_mode: ForwardMode, + seq_lens_cpu: Optional[torch.Tensor], + ) -> None: + """Update metadata when replaying a CUDA graph for decode.""" + if not forward_mode.is_decode_or_idle(): + raise ValueError( + "CUDA-graph replay is only supported for decode / idle modes." + ) + + self.indices_updater_decode.update( + req_pool_indices[:bs], + seq_lens[:bs], + seq_lens_cpu[:bs] if seq_lens_cpu is not None else None, + seq_lens_sum, + decode_wrappers=self.decode_cuda_graph_metadata[bs], + ) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _get_wrapper_idx(self, layer) -> int: + """Return the wrapper index for the given attention layer.""" + if self.num_wrappers == 1: + return 0 + if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: + # Wrapper 0 → sliding window attention. + # Wrapper 1 → full-context attention. 
+ return int(layer.sliding_window_size == -1) + raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}") + + +# --------------------------------------------------------------------------- +# _FlashInferIndicesUpdaterDecode +# --------------------------------------------------------------------------- + + +class _FlashInferIndicesUpdaterDecode: + """Populates ``kv_indptr`` / ``kv_indices`` and calls + ``wrapper.begin_forward`` before every decode step. + """ + + def __init__(self, backend: FlashInferAttnBackend): + self.num_qo_heads = backend.num_heads + self.num_kv_heads = backend.num_kv_heads + self.head_dim = backend.head_dim + self.data_type = backend.kv_cache_dtype + self.q_data_type = backend.q_dtype + self.sliding_window_size = backend.sliding_window_size + self.backend = backend + + self.kv_indptr = backend.kv_indptr + self.kv_last_page_len = backend.kv_last_page_len + self.req_to_token = backend.req_to_token + + def update( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], + seq_lens_sum: int, + decode_wrappers: "List[BatchDecodeWithPagedKVCacheWrapper]", + kv_start_idx: Optional[torch.Tensor] = None, + ) -> None: + if self.backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: + self._update_sliding_window( + req_pool_indices, + seq_lens, + seq_lens_cpu, + seq_lens_sum, + decode_wrappers, + ) + else: + # Single-wrapper: full-context decode. Build kv_indptr/kv_indices + # and call FlashInfer's plan function via the CUDA kernel. + bs = len(req_pool_indices) + kv_indptr = self.kv_indptr[0] + + # Fill kv_indptr: prefix sums of paged_kernel_lens. + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indptr_sliced = kv_indptr[: bs + 1] + + if seq_lens_cpu is not None: + seq_lens_sum = int(seq_lens_cpu.sum().item()) + else: + seq_lens_sum = int(seq_lens.sum().item()) + + # Allocate KV indices buffer. 
+ if decode_wrappers and decode_wrappers[0].is_cuda_graph_enabled: + kv_indices = decode_wrappers[0]._paged_kv_indices_buf + else: + kv_indices = torch.empty( + seq_lens_sum, dtype=torch.int32, device=self.req_to_token.device + ) + + # Use high-performance CUDA kernel to populate kv_indices. + create_kv_indices( + self.req_to_token, + req_pool_indices.to(torch.int32), + seq_lens.to(torch.int32), + kv_indptr_sliced, + None, + kv_indices, + ) + + decode_wrappers = decode_wrappers or self.decode_wrappers + decode_wrappers[0].begin_forward( + kv_indptr_sliced, + kv_indices, + self.kv_last_page_len[:bs], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + data_type=self.data_type, + q_data_type=self.q_data_type, + non_blocking=True, + ) + + def _update_sliding_window( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], + seq_lens_sum: int, + decode_wrappers: "List[BatchDecodeWithPagedKVCacheWrapper]", + ) -> None: + assert self.sliding_window_size is not None + for wrapper_id in range(2): + if wrapper_id == 0: + # Sliding-window attention: clamp to window size + 1 + paged_kernel_lens = torch.clamp( + seq_lens, max=self.sliding_window_size + 1 + ) + paged_kernel_lens_sum = int(paged_kernel_lens.sum().item()) + kv_start_idx = seq_lens - paged_kernel_lens + seq_lens_cpu_tmp = ( + torch.clamp(seq_lens_cpu, max=self.sliding_window_size + 1) + if seq_lens_cpu is not None + else None + ) + else: + # Full-context attention + paged_kernel_lens = seq_lens + paged_kernel_lens_sum = seq_lens_sum + kv_start_idx = None + seq_lens_cpu_tmp = seq_lens_cpu + + bs = len(req_pool_indices) + kv_indptr = self.kv_indptr[wrapper_id] + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr_sliced = kv_indptr[: bs + 1] + + if decode_wrappers and decode_wrappers[wrapper_id].is_cuda_graph_enabled: + kv_indices = decode_wrappers[wrapper_id]._paged_kv_indices_buf + else: + kv_indices = torch.empty( + 
paged_kernel_lens_sum, + dtype=torch.int32, + device=self.req_to_token.device, + ) + + # High-performance CUDA kernel populates kv_indices from req_to_token. + create_kv_indices( + self.req_to_token, + req_pool_indices.to(torch.int32), + paged_kernel_lens.to(torch.int32), + kv_indptr_sliced, + kv_start_idx.to(torch.int32) if kv_start_idx is not None else None, + kv_indices, + ) + + decode_wrappers[wrapper_id].begin_forward( + kv_indptr_sliced, + kv_indices, + self.kv_last_page_len[:bs], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + data_type=self.data_type, + q_data_type=self.q_data_type, + non_blocking=True, + ) + + +# --------------------------------------------------------------------------- +# _FlashInferIndicesUpdaterPrefill +# --------------------------------------------------------------------------- + + +class _FlashInferIndicesUpdaterPrefill: + """Populates indices and calls ``wrapper.begin_forward`` before extend.""" + + def __init__(self, backend: FlashInferAttnBackend): + self.num_qo_heads = backend.num_heads + self.num_kv_heads = backend.num_kv_heads + self.head_dim = backend.head_dim + self.data_type = backend.kv_cache_dtype + self.q_data_type = backend.q_dtype + self.sliding_window_size = backend.sliding_window_size + self.backend = backend + + self.kv_indptr = backend.kv_indptr + self.kv_last_page_len = backend.kv_last_page_len + self.qo_indptr = backend.qo_indptr + self.req_to_token = backend.req_to_token + self.prefill_wrapper_ragged = backend.prefill_wrapper_ragged + + def update( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], + seq_lens_sum: int, + prefix_lens: Optional[torch.Tensor], + prefill_wrappers: "List[BatchPrefillWithPagedKVCacheWrapper]", + use_ragged: bool, + ) -> None: + if self.backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: + self._update_sliding_window( + req_pool_indices, + seq_lens, + seq_lens_cpu, + seq_lens_sum, + prefix_lens, + 
prefill_wrappers, + use_ragged, + ) + else: + if use_ragged: + paged_kernel_lens = prefix_lens + paged_kernel_lens_sum = paged_kernel_lens.sum().item() + else: + paged_kernel_lens = seq_lens + paged_kernel_lens_sum = seq_lens_sum + + self._call_begin_forward( + self.prefill_wrapper_ragged, + prefill_wrappers[0], + req_pool_indices, + paged_kernel_lens, + paged_kernel_lens_sum, + seq_lens, + prefix_lens, + kv_start_idx=None, + kv_indptr=self.kv_indptr[0], + qo_indptr=self.qo_indptr[0], + use_ragged=use_ragged, + ) + + def _update_sliding_window( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], + seq_lens_sum: int, + prefix_lens: Optional[torch.Tensor], + prefill_wrappers: "List[BatchPrefillWithPagedKVCacheWrapper]", + use_ragged: bool, + ) -> None: + assert self.sliding_window_size is not None + for wrapper_id in range(2): + if wrapper_id == 0: + # Sliding-window portion uses a limited context window. + extend_lens = seq_lens - prefix_lens + paged_kernel_lens = torch.minimum( + seq_lens, + torch.tensor(self.sliding_window_size, device=seq_lens.device) + + extend_lens, + ) + paged_kernel_lens_sum = int(paged_kernel_lens.sum().item()) + kv_start_idx = seq_lens - paged_kernel_lens + else: + # Full-context portion. 
+ paged_kernel_lens = seq_lens + paged_kernel_lens_sum = seq_lens_sum + kv_start_idx = None + + kv_indptr = self.kv_indptr[wrapper_id] + qo_indptr = self.qo_indptr[wrapper_id] + + self._call_begin_forward( + self.prefill_wrapper_ragged, + prefill_wrappers[wrapper_id], + req_pool_indices, + paged_kernel_lens, + paged_kernel_lens_sum, + seq_lens, + prefix_lens, + kv_start_idx=kv_start_idx, + kv_indptr=kv_indptr, + qo_indptr=qo_indptr, + use_ragged=use_ragged, + ) + + def _call_begin_forward( + self, + wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper", + wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper", + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + paged_kernel_lens_sum: int, + seq_lens: torch.Tensor, + prefix_lens: Optional[torch.Tensor], + kv_start_idx: Optional[torch.Tensor], + kv_indptr: torch.Tensor, + qo_indptr: torch.Tensor, + use_ragged: bool, + ) -> None: + bs = len(seq_lens) + + # Build kv_indptr and kv_indices using the CUDA kernel. + kv_indptr_sliced = kv_indptr[: bs + 1] + kv_indptr_sliced[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_indices = torch.empty( + paged_kernel_lens_sum + 256, + dtype=torch.int32, + device=req_pool_indices.device, + ) + + create_kv_indices( + self.req_to_token, + req_pool_indices.to(torch.int32), + paged_kernel_lens.to(torch.int32), + kv_indptr_sliced, + kv_start_idx.to(torch.int32) if kv_start_idx is not None else None, + kv_indices, + ) + + # Build qo_indptr (number of new tokens per sequence). + if prefix_lens is not None: + extend_lens = seq_lens - prefix_lens + else: + extend_lens = seq_lens + qo_indptr_sliced = qo_indptr[: bs + 1] + qo_indptr_sliced[1:] = torch.cumsum(extend_lens, dim=0) + + # Plan the ragged wrapper (new tokens only). + if use_ragged: + wrapper_ragged.begin_forward( + qo_indptr_sliced, + qo_indptr_sliced, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + q_data_type=self.q_data_type, + ) + + # Plan the paged wrapper (cached prefix tokens). 
+ wrapper_paged.begin_forward( + qo_indptr_sliced, + kv_indptr_sliced, + kv_indices, + self.kv_last_page_len[:bs], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + q_data_type=self.q_data_type, + kv_data_type=self.data_type, + non_blocking=True, + ) diff --git a/pymllm/backends/cuda/__init__.py b/pymllm/layers/attention/gdn.py similarity index 100% rename from pymllm/backends/cuda/__init__.py rename to pymllm/layers/attention/gdn.py diff --git a/pymllm/layers/attention/gdn_backend.py b/pymllm/layers/attention/gdn_backend.py new file mode 100644 index 000000000..2b6e27b48 --- /dev/null +++ b/pymllm/layers/attention/gdn_backend.py @@ -0,0 +1,660 @@ +"""GDN attention backend -- pooled-state GDN computation for hybrid models. + +Performs GDN (Gated Delta Net) linear-attention using externalized state +stored in a :class:`~pymllm.mem_cache.memory_pool.GDNPool`. Supports +both extend (prefill) and decode paths with FlashInfer kernels. + +This backend is not used directly; it is wrapped by +:class:`~pymllm.layers.attention.hybrid_backend.HybridAttnBackend`. +""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn.functional as F + +if TYPE_CHECKING: + from pymllm.engine.forward_batch import ForwardBatch + from pymllm.layers.attention.radix_linear_attention import RadixLinearAttention + from pymllm.mem_cache.memory_pool import GDNPool + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Server config: gdn_decode_backend override +# --------------------------------------------------------------------------- + + +def _get_gdn_decode_backend_override() -> str: + """Read ``server.gdn_decode_backend`` from GlobalConfig. + + Returns one of: ``"auto"``, ``"flashinfer"``, ``"mllm_kernel"``, ``"pytorch"``. 
+ """ + try: + from pymllm.configs import get_global_config + return get_global_config().server.gdn_decode_backend + except Exception: + return "auto" + + +# --------------------------------------------------------------------------- +# mllm-kernel GDN decode (lazy import, SM80+) +# --------------------------------------------------------------------------- + +_mllm_gdn_decode = None + + +def _get_mllm_gdn_decode(): + """Lazy import for mllm-kernel fused GDN decode CUDA kernel.""" + global _mllm_gdn_decode + if _mllm_gdn_decode is None: + try: + from mllm_kernel.cuda.jit.gdn_decode import gdn_decode + + _mllm_gdn_decode = gdn_decode + logger.info("GDNAttnBackend: [probe] mllm-kernel GDN decode available (SM80+)") + except (ImportError, RuntimeError) as e: + logger.info("GDNAttnBackend: [probe] mllm-kernel GDN decode not available: %s", e) + _mllm_gdn_decode = False + return _mllm_gdn_decode if _mllm_gdn_decode is not False else None + + +# --------------------------------------------------------------------------- +# FlashInfer GDN kernel (lazy import) +# --------------------------------------------------------------------------- + +_flashinfer_available: Optional[bool] = None +_fi_chunk_gated_delta_rule = None +_fi_gated_delta_rule_decode = None + + +def _get_flashinfer_gdn(): + """Lazy import for FlashInfer GDN kernels (prefill + decode).""" + global _flashinfer_available, _fi_chunk_gated_delta_rule, _fi_gated_delta_rule_decode + if _flashinfer_available is None: + try: + os.environ.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1") + _flashinfer_available = ( + torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 9 + ) + if not _flashinfer_available: + logger.info( + "GDNAttnBackend: [probe] FlashInfer GDN not available (requires SM90+, " + "current SM%d%d)", *torch.cuda.get_device_capability() + ) + return _flashinfer_available, None, None + + from flashinfer.gdn_prefill import chunk_gated_delta_rule + _fi_chunk_gated_delta_rule = 
chunk_gated_delta_rule + + try: + from flashinfer.gdn_decode import gated_delta_rule_decode_pretranspose + _fi_gated_delta_rule_decode = gated_delta_rule_decode_pretranspose + logger.info("GDNAttnBackend: [probe] FlashInfer GDN available (prefill + decode)") + except ImportError: + logger.info( + "GDNAttnBackend: [probe] FlashInfer GDN partially available " + "(prefill only, decode not found)" + ) + except (ImportError, RuntimeError) as e: + logger.info( + "GDNAttnBackend: [probe] FlashInfer GDN not available: %s", e + ) + _flashinfer_available = False + return _flashinfer_available, _fi_chunk_gated_delta_rule, _fi_gated_delta_rule_decode + + +# --------------------------------------------------------------------------- +# GDN gating computation +# --------------------------------------------------------------------------- + + +def _gdn_gating( + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute GDN gating factors. + + Returns + ------- + g : log-space decay factor: -exp(A_log) * softplus(a + dt_bias) + beta : update gate: sigmoid(b) + """ + g = -torch.exp(A_log) * F.softplus(a + dt_bias) + beta = torch.sigmoid(b) + return g, beta + + +# --------------------------------------------------------------------------- +# Forward metadata +# --------------------------------------------------------------------------- + + +@dataclass +class GDNForwardMetadata: + """Per-batch metadata for GDN backend.""" + + cache_indices: torch.Tensor # [batch_size] = req_pool_indices + cu_seqlens: Optional[torch.Tensor] = None # extend only + + +# --------------------------------------------------------------------------- +# GDNAttnBackend +# --------------------------------------------------------------------------- + + +class GDNAttnBackend: + """GDN linear-attention backend using pooled states. + + Handles both extend (prefill) and decode paths for GDN layers. 
    Uses FlashInfer kernels when available (SM90+), with PyTorch fallback.

    Parameters
    ----------
    gdn_pool
        Pre-allocated :class:`~pymllm.mem_cache.memory_pool.GDNPool`.
    device
        Target device.
    """

    def __init__(self, gdn_pool: "GDNPool", device: torch.device):
        self.gdn_pool = gdn_pool
        self.device = device
        self.forward_metadata: Optional[GDNForwardMetadata] = None

        # Pre-check FlashInfer availability
        self._use_flashinfer, _, _ = _get_flashinfer_gdn()

        # One-shot flags to log the selected backend on first actual forward call
        self._decode_backend_logged = False
        self._extend_backend_logged = False

    def init_forward_metadata(self, forward_batch: "ForwardBatch") -> None:
        """Prepare GDN metadata from the current forward batch."""
        cache_indices = forward_batch.req_pool_indices.to(torch.int64)

        cu_seqlens = None
        if forward_batch.forward_mode.is_extend():
            # Build cu_seqlens from extend_seq_lens
            # NOTE(review): if extend_seq_lens is None in extend mode,
            # cu_seqlens stays None and forward_extend would fail when
            # indexing it — confirm the scheduler always sets it.
            if forward_batch.extend_seq_lens is not None:
                seq_lens = forward_batch.extend_seq_lens.to(torch.int64)
                cu_seqlens = torch.zeros(
                    len(seq_lens) + 1,
                    dtype=torch.int64,
                    device=self.device,
                )
                torch.cumsum(seq_lens, dim=0, out=cu_seqlens[1:])

        self.forward_metadata = GDNForwardMetadata(
            cache_indices=cache_indices,
            cu_seqlens=cu_seqlens,
        )

    # ------------------------------------------------------------------
    # CUDA-graph interface
    # ------------------------------------------------------------------

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None:
        """Allocate CUDA-graph state for GDN backend.

        The GDN pool buffers are already pre-allocated at fixed addresses,
        so we only need to allocate the metadata tensor.
        """
        self._cuda_graph_cache_indices = torch.zeros(
            (max_bs,), dtype=torch.int64, device=self.device
        )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> None:
        """Set up GDN metadata for CUDA-graph capture (decode only)."""
        self._cuda_graph_cache_indices[:bs].copy_(
            req_pool_indices[:bs].to(torch.int64)
        )
        self.forward_metadata = GDNForwardMetadata(
            cache_indices=self._cuda_graph_cache_indices[:bs],
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> None:
        """Update GDN metadata for CUDA-graph replay (decode only)."""
        self._cuda_graph_cache_indices[:bs].copy_(
            req_pool_indices[:bs].to(torch.int64)
        )
        self.forward_metadata = GDNForwardMetadata(
            cache_indices=self._cuda_graph_cache_indices[:bs],
        )

    # ------------------------------------------------------------------
    # Forward: decode
    # ------------------------------------------------------------------

    def forward_decode(
        self,
        layer: "RadixLinearAttention",
        forward_batch: "ForwardBatch",
        mixed_qkv: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        """GDN decode: one new token per request.

        Steps:
          1. Gather conv_state from pool → [bs, conv_dim, K-1]
          2. Conv1d update: shift + weighted sum for 1 new token
          3. Scatter updated conv_state back to pool
          4. SiLU → split q,k,v
          5. FlashInfer gated_delta_rule_decode (or PyTorch fallback)
        """
        # NOTE(review): assumes init_forward_metadata (or the CUDA-graph
        # variants) ran for this batch; forward_metadata is None otherwise.
        metadata = self.forward_metadata
        cache_indices = metadata.cache_indices
        gdn_idx = layer.gdn_layer_idx
        bs = mixed_qkv.shape[0]

        recurrent_buf, conv_buf = self.gdn_pool.get_layer_state(gdn_idx)
        conv_weight = layer.conv_weight  # [conv_dim, kernel_size]
        K = conv_weight.shape[1]

        # --- Conv1d decode: single-token update ---
        conv_state = conv_buf[cache_indices]  # [bs, conv_dim, K-1]
        x = mixed_qkv.unsqueeze(-1)  # [bs, conv_dim, 1]

        # Shift window left by one and append the new token.
        new_conv_state = torch.cat([conv_state[:, :, 1:], x], dim=-1)
        full_window = torch.cat([conv_state, x], dim=-1)  # [bs, conv_dim, K]
        conv_out = (full_window * conv_weight.unsqueeze(0)).sum(dim=-1)

        conv_buf[cache_indices] = new_conv_state

        # --- SiLU activation ---
        conv_out = F.silu(conv_out)

        # --- Split q, k, v ---
        key_dim = layer.num_k_heads * layer.head_k_dim
        value_dim = layer.num_v_heads * layer.head_v_dim
        q, k, v = conv_out.split([key_dim, key_dim, value_dim], dim=-1)
        q = q.view(bs, layer.num_k_heads, layer.head_k_dim)
        k = k.view(bs, layer.num_k_heads, layer.head_k_dim)
        v = v.view(bs, layer.num_v_heads, layer.head_v_dim)

        # --- Recurrent update ---
        # Priority (when "auto"): FlashInfer SM90+ > mllm-kernel SM80+ > PyTorch
        # Can be overridden via --server.gdn_decode_backend
        backend = _get_gdn_decode_backend_override()
        use_fi, _, fi_decode = _get_flashinfer_gdn()
        mllm_gdn = _get_mllm_gdn_decode()

        use_flashinfer = (
            (backend in ("auto", "flashinfer"))
            and use_fi and fi_decode is not None
            and mixed_qkv.is_cuda
        )
        use_mllm = (
            (backend in ("auto", "mllm_kernel"))
            and not (backend == "auto" and use_flashinfer)
            and mllm_gdn is not None
            and mixed_qkv.is_cuda
        )

        # An explicit-but-unavailable request falls through to the next
        # available implementation (possibly the PyTorch path below).
        if backend == "flashinfer" and not use_flashinfer:
            logger.warning("GDNAttnBackend: gdn_decode_backend='flashinfer' requested but unavailable, falling back")
        if backend == "mllm_kernel" and mllm_gdn is None:
            logger.warning("GDNAttnBackend: gdn_decode_backend='mllm_kernel' requested but unavailable, falling back")

        if not self._decode_backend_logged:
            if use_flashinfer:
                selected = "flashinfer"
            elif use_mllm:
                selected = "mllm_kernel"
            else:
                selected = "pytorch"
            logger.info(
                "GDNAttnBackend: [decode] using backend=%s (config=%s)", selected, backend
            )
            self._decode_backend_logged = True

        if use_flashinfer:
            # FlashInfer decode (SM90+)
            query_fi = q.unsqueeze(1)
            key_fi = k.unsqueeze(1)
            value_fi = v.unsqueeze(1)
            a_fi = a.unsqueeze(1)
            b_fi = b.unsqueeze(1)

            state_batch = recurrent_buf[cache_indices]

            output_fi, new_state = fi_decode(
                q=query_fi, k=key_fi, v=value_fi,
                state=state_batch,
                A_log=layer.A_log.detach(),
                a=a_fi, dt_bias=layer.dt_bias.detach(), b=b_fi,
                scale=None, output=None, use_qk_l2norm=True,
            )

            recurrent_buf[cache_indices] = new_state
            output = output_fi.squeeze(1)

        elif use_mllm:
            # mllm-kernel fused CUDA decode (SM80+); updates recurrent_buf
            # in place via cache_indices.
            output = mllm_gdn(
                q, k, v, a, b,
                layer.A_log, layer.dt_bias,
                recurrent_buf, cache_indices,
            )

        else:
            # PyTorch fallback
            g, beta = _gdn_gating(a, b, layer.A_log, layer.dt_bias)
            output = self._decode_pytorch_fallback(
                q, k, v, g, beta, recurrent_buf, cache_indices, layer
            )

        return output.reshape(bs, value_dim)

    def _decode_pytorch_fallback(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        recurrent_buf: torch.Tensor,
        cache_indices: torch.Tensor,
        layer: "RadixLinearAttention",
    ) -> torch.Tensor:
        """Pure PyTorch decode fallback for GDN with delta rule and L2 norm.

        Matches the sglang Triton kernel (fused_sigmoid_gating_delta_rule_update):
            state *= exp(g)           # decay
            v_delta = v - state @ k   # delta rule
            v_delta *= beta           # gating
            state += v_delta ⊗ k      # state update (outer product)
            output = state @ q        # readout
        """
        bs = q.shape[0]
        num_v_heads = layer.num_v_heads
        num_k_heads = layer.num_k_heads

        # GQA: expand k/q heads to match v heads
        if num_k_heads != num_v_heads:
            repeats = num_v_heads // num_k_heads
            q = q.repeat_interleave(repeats, dim=1)
            k = k.repeat_interleave(repeats, dim=1)

        # All computation in float32 (state is float32, avoids dtype mismatch)
        orig_dtype = q.dtype
        q = q.float()
        k = k.float()
        v = v.float()

        # L2 normalize q and k per-head (matching use_qk_l2norm_in_kernel=True)
        q = q / (q.norm(dim=-1, keepdim=True) + 1e-6)
        k = k / (k.norm(dim=-1, keepdim=True) + 1e-6)

        decay = torch.exp(g.float())  # [bs, num_v_heads]
        beta_f = beta.float()  # [bs, num_v_heads]

        outputs = []
        for i in range(bs):
            idx = cache_indices[i]
            state = recurrent_buf[idx]  # [H, V, K] float32

            # Decay
            state = state * decay[i].unsqueeze(-1).unsqueeze(-1)

            k_i = k[i]  # [H, K]
            v_i = v[i]  # [H, V]
            b_i = beta_f[i]  # [H]
            q_i = q[i]  # [H, K]

            # Delta rule: v_delta = v - state @ k
            v_delta = v_i - torch.bmm(state, k_i.unsqueeze(-1)).squeeze(-1)
            v_delta = v_delta * b_i.unsqueeze(-1)  # gating

            # State update: state += v_delta ⊗ k (outer product in [V, K] layout)
            state = state + v_delta.unsqueeze(-1) * k_i.unsqueeze(-2)
            recurrent_buf[idx] = state

            # Output: o = state @ q
            o_t = torch.bmm(state, q_i.unsqueeze(-1)).squeeze(-1)  # [H, V]
            outputs.append(o_t)

        return torch.stack(outputs, dim=0).to(orig_dtype)  # [bs, H, V]

    # ------------------------------------------------------------------
    # Forward: extend (prefill)
    # ------------------------------------------------------------------

    def forward_extend(
        self,
        layer: "RadixLinearAttention",
        forward_batch: "ForwardBatch",
        mixed_qkv: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        """GDN extend (prefill): multi-token per request.

        Steps:
          1. Gather conv_state from pool for each request
          2. Per-request causal conv1d
          3. Scatter new conv_state back to pool
          4. SiLU → split q,k,v → gating
          5. FlashInfer chunk_gated_delta_rule (or PyTorch fallback)
          6. Scatter final recurrent state back to pool
        """
        metadata = self.forward_metadata
        cache_indices = metadata.cache_indices
        cu_seqlens = metadata.cu_seqlens
        gdn_idx = layer.gdn_layer_idx
        total_tokens = mixed_qkv.shape[0]

        recurrent_buf, conv_buf = self.gdn_pool.get_layer_state(gdn_idx)
        conv_weight = layer.conv_weight  # [conv_dim, kernel_size]
        K = conv_weight.shape[1]
        batch_size = cache_indices.shape[0]

        key_dim = layer.num_k_heads * layer.head_k_dim
        value_dim = layer.num_v_heads * layer.head_v_dim

        # --- Per-request causal conv1d ---
        conv_out = torch.empty_like(mixed_qkv)  # [total_tokens, conv_dim]

        for i in range(batch_size):
            start = int(cu_seqlens[i].item())
            end = int(cu_seqlens[i + 1].item())
            seq_len = end - start
            if seq_len == 0:
                continue

            idx = cache_indices[i]
            x = mixed_qkv[start:end]  # [seq_len, conv_dim]
            prev_state = conv_buf[idx]  # [conv_dim, K-1]

            # Pad with previous conv state
            x_padded = torch.cat([prev_state.T, x], dim=0)  # [K-1+seq_len, conv_dim]

            # Save new conv state (last K-1 tokens)
            conv_buf[idx] = x_padded[-(K - 1):].T.clone()

            # Causal conv1d
            out = torch.zeros(seq_len, x.shape[1], device=x.device, dtype=x.dtype)
            for kk in range(K):
                out += x_padded[kk: kk + seq_len] * conv_weight[:, kk]
            conv_out[start:end] = out

        # --- SiLU activation ---
        conv_out = F.silu(conv_out)

        # --- Split q, k, v ---
        q, k, v = conv_out.split([key_dim, key_dim, value_dim], dim=-1)
        q = q.view(total_tokens, layer.num_k_heads, layer.head_k_dim)
        k = k.view(total_tokens, layer.num_k_heads, layer.head_k_dim)
        v = v.view(total_tokens, layer.num_v_heads, layer.head_v_dim)

        # --- GDN gating ---
        g, beta = _gdn_gating(a, b, layer.A_log, layer.dt_bias)

        # --- Recurrent computation ---
        use_fi, fi_prefill, _ = _get_flashinfer_gdn()
        use_fi_extend = use_fi and fi_prefill is not None and mixed_qkv.is_cuda

        if not self._extend_backend_logged:
            logger.info(
                "GDNAttnBackend: [extend] using backend=%s",
                "flashinfer" if use_fi_extend else "pytorch",
            )
            self._extend_backend_logged = True

        if use_fi_extend:
            # Gather initial states for this batch
            init_state = recurrent_buf[cache_indices].to(torch.float32)
            # [batch_size, num_v_heads, head_v_dim, head_k_dim]

            alpha = torch.exp(g.to(torch.float32))
            beta_f32 = beta.to(torch.float32)

            # FlashInfer's use_qk_l2norm_in_kernel is silently ignored —
            # the flag is declared in the Python wrapper but never forwarded
            # to the CUDA kernel. Pre-normalize q and k here, matching
            # sglang's approach (l2norm_fwd before calling with False).
            q_fi = q / (q.norm(dim=-1, keepdim=True) + 1e-6)
            k_fi = k / (k.norm(dim=-1, keepdim=True) + 1e-6)

            output, final_state = fi_prefill(
                q=q_fi.contiguous(),
                k=k_fi.contiguous(),
                v=v.contiguous(),
                g=alpha,
                beta=beta_f32,
                initial_state=init_state,
                output_final_state=True,
                cu_seqlens=cu_seqlens,
                use_qk_l2norm_in_kernel=False,
            )

            # Scatter final states back to pool
            recurrent_buf[cache_indices] = final_state.to(recurrent_buf.dtype)
        else:
            # PyTorch fallback: per-request sequential scan
            output = self._extend_pytorch_fallback(
                q, k, v, g, beta, recurrent_buf, cache_indices, cu_seqlens, layer
            )

        return output.reshape(total_tokens, value_dim)

    def _extend_pytorch_fallback(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        recurrent_buf: torch.Tensor,
        cache_indices: torch.Tensor,
        cu_seqlens: torch.Tensor,
        layer: "RadixLinearAttention",
    ) -> torch.Tensor:
        """Pure PyTorch extend fallback for GDN with delta rule and L2 norm."""
        total_tokens = q.shape[0]
        num_v_heads = layer.num_v_heads
        num_k_heads = layer.num_k_heads
        head_v_dim = layer.head_v_dim
        batch_size = cache_indices.shape[0]

        # All computation in float32
        orig_dtype = q.dtype
        q = q.float()
        k = k.float()
        v = v.float()

        # L2 normalize q and k per-head
        q = q / (q.norm(dim=-1, keepdim=True) + 1e-6)
        k = k / (k.norm(dim=-1, keepdim=True) + 1e-6)

        # GQA expansion
        if num_k_heads != num_v_heads:
            repeats = num_v_heads // num_k_heads
            q = q.repeat_interleave(repeats, dim=1)
            k = k.repeat_interleave(repeats, dim=1)

        output = torch.zeros(
            total_tokens, num_v_heads, head_v_dim,
            device=q.device, dtype=torch.float32,
        )

        for i in range(batch_size):
            start = int(cu_seqlens[i].item())
            end = int(cu_seqlens[i + 1].item())
            seq_len = end - start
            if seq_len == 0:
                continue

            idx = cache_indices[i]
            q_seq = q[start:end]
            k_seq = k[start:end]
            v_seq = v[start:end]
            g_seq = g[start:end]
            beta_seq = beta[start:end]

            decay = torch.exp(g_seq.float())  # [seq_len, H]
            beta_f = beta_seq.float()  # [seq_len, H]
            state = recurrent_buf[idx].clone()  # [H, V, K] float32

            seq_outputs = []
            for t in range(seq_len):
                # Decay
                state = state * decay[t].unsqueeze(-1).unsqueeze(-1)

                k_t = k_seq[t]  # [H, K]
                v_t = v_seq[t]  # [H, V]
                b_t = beta_f[t]  # [H]
                q_t = q_seq[t]  # [H, K]

                # Delta rule: v_delta = v - state @ k
                v_delta = v_t - torch.bmm(state, k_t.unsqueeze(-1)).squeeze(-1)
                v_delta = v_delta * b_t.unsqueeze(-1)

                # State update
                state = state + v_delta.unsqueeze(-1) * k_t.unsqueeze(-2)

                # Output
                o_t = torch.bmm(state, q_t.unsqueeze(-1)).squeeze(-1)
                seq_outputs.append(o_t)

            recurrent_buf[idx] = state
            output[start:end] = torch.stack(seq_outputs, dim=0)

        return output.to(orig_dtype)

    # ------------------------------------------------------------------
    # Dispatch entry point
    # ------------------------------------------------------------------
    # ------------------------------------------------------------------

    def forward_gdn(
        self,
        layer: "RadixLinearAttention",
        forward_batch: "ForwardBatch",
        mixed_qkv: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        """Route to decode or extend based on forward mode."""
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(layer, forward_batch, mixed_qkv, a, b)
        else:
            return self.forward_extend(layer, forward_batch, mixed_qkv, a, b)
diff --git a/pymllm/layers/attention/hybrid_backend.py b/pymllm/layers/attention/hybrid_backend.py
new file mode 100644
index 000000000..a5628259e
--- /dev/null
+++ b/pymllm/layers/attention/hybrid_backend.py
@@ -0,0 +1,184 @@
"""Hybrid attention backend -- FlashInfer + GDN for hybrid architectures.

Wraps a :class:`FlashInferAttnBackend` (for full-attention layers) and a
:class:`GDNAttnBackend` (for GDN linear-attention layers). Dispatches
based on layer type:

* ``RadixAttention`` calls → delegated to ``full_attn_backend``
* ``RadixLinearAttention`` calls (via ``forward_gdn``) → delegated to ``gdn_backend``

CUDA-graph compatible: delegates all graph lifecycle methods to both
sub-backends.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Optional, Set

import torch

from pymllm.layers.attention.attention_backend import AttentionBackend

if TYPE_CHECKING:
    from pymllm.engine.forward_batch import ForwardBatch, ForwardMode
    from pymllm.layers.attention.flashinfer_backend import FlashInferAttnBackend
    from pymllm.layers.attention.gdn_backend import GDNAttnBackend
    from pymllm.layers.attention.radix_attention import RadixAttention
    from pymllm.layers.attention.radix_linear_attention import RadixLinearAttention

logger = logging.getLogger(__name__)


class HybridAttnBackend(AttentionBackend):
    """Composite attention backend for hybrid full-attention + GDN models.

    Parameters
    ----------
    full_attn_backend
        FlashInfer backend for standard transformer attention layers.
    gdn_backend
        GDN backend for linear-attention layers.
    full_attn_layer_ids
        Set of global layer IDs that use full attention (for logging).
    """

    def __init__(
        self,
        full_attn_backend: "FlashInferAttnBackend",
        gdn_backend: "GDNAttnBackend",
        full_attn_layer_ids: Set[int],
    ):
        self.full_attn_backend = full_attn_backend
        self.gdn_backend = gdn_backend
        self.full_attn_layer_ids = full_attn_layer_ids

        logger.info(
            "HybridAttnBackend created: %d full-attn layers, "
            "%d GDN layers",
            len(full_attn_layer_ids),
            gdn_backend.gdn_pool.num_gdn_layers,
        )

    # ------------------------------------------------------------------
    # Core interface: init_forward_metadata
    # ------------------------------------------------------------------

    def init_forward_metadata(self, forward_batch: "ForwardBatch") -> None:
        """Initialize metadata for both sub-backends."""
        self.full_attn_backend.init_forward_metadata(forward_batch)
        self.gdn_backend.init_forward_metadata(forward_batch)

    # ------------------------------------------------------------------
    # Full attention: forward_decode / forward_extend
    # ------------------------------------------------------------------

    def forward_decode(
        self,
        q: torch.Tensor,
        k: Optional[torch.Tensor],
        v: Optional[torch.Tensor],
        layer: "RadixAttention",
        forward_batch: "ForwardBatch",
        save_kv_cache: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Delegate full-attention decode to FlashInfer backend."""
        return self.full_attn_backend.forward_decode(
            q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs
        )

    def forward_extend(
        self,
        q: torch.Tensor,
        k: Optional[torch.Tensor],
        v: Optional[torch.Tensor],
        layer: "RadixAttention",
        forward_batch: "ForwardBatch",
        save_kv_cache: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Delegate full-attention extend to FlashInfer backend."""
        return self.full_attn_backend.forward_extend(
            q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs
        )

    # ------------------------------------------------------------------
    # GDN linear attention: forward_gdn
    # ------------------------------------------------------------------

    def forward_gdn(
        self,
        layer: "RadixLinearAttention",
        forward_batch: "ForwardBatch",
        mixed_qkv: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate GDN computation to the GDN backend."""
        return self.gdn_backend.forward_gdn(
            layer=layer,
            forward_batch=forward_batch,
            mixed_qkv=mixed_qkv,
            a=a,
            b=b,
        )

    # ------------------------------------------------------------------
    # CUDA-graph interface: delegate to both sub-backends
    # ------------------------------------------------------------------

    def get_cuda_graph_seq_len_fill_value(self) -> int:
        """Delegate to the full-attention backend."""
        return self.full_attn_backend.get_cuda_graph_seq_len_fill_value()

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None:
        """Allocate CUDA-graph state for both sub-backends."""
        self.full_attn_backend.init_cuda_graph_state(max_bs, max_num_tokens)
        self.gdn_backend.init_cuda_graph_state(max_bs, max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        forward_mode: "ForwardMode",
    ) -> None:
        """Set up metadata for CUDA-graph capture in both sub-backends."""
        self.full_attn_backend.init_forward_metadata_capture_cuda_graph(
            bs=bs,
            num_tokens=num_tokens,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            forward_mode=forward_mode,
        )
        # The GDN backend takes a narrower signature (no num_tokens /
        # forward_mode) — it only needs the per-request cache indices.
        self.gdn_backend.init_forward_metadata_capture_cuda_graph(
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        forward_mode: "ForwardMode",
        seq_lens_cpu: Optional[torch.Tensor],
    ) -> None:
        """Update metadata for CUDA-graph replay in both sub-backends."""
        self.full_attn_backend.init_forward_metadata_replay_cuda_graph(
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            seq_lens_sum=seq_lens_sum,
            forward_mode=forward_mode,
            seq_lens_cpu=seq_lens_cpu,
        )
        self.gdn_backend.init_forward_metadata_replay_cuda_graph(
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
        )
diff --git a/pymllm/layers/attention/radix_attention.py b/pymllm/layers/attention/radix_attention.py
new file mode 100644
index 000000000..114130dbf
--- /dev/null
+++ b/pymllm/layers/attention/radix_attention.py
@@ -0,0 +1,171 @@
"""RadixAttention -- the attention layer used by pymllm models.

This module is kept small intentionally: all heavy computation is delegated
to the pluggable ``AttentionBackend`` that is attached to the ``ForwardBatch``.
"""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Optional

import torch
from torch import nn

if TYPE_CHECKING:
    from pymllm.engine.forward_batch import ForwardBatch


# ---------------------------------------------------------------------------
# AttentionType
# ---------------------------------------------------------------------------


class AttentionType(Enum):
    """Attention variant used by a :class:`RadixAttention` layer.

    Uses string values so that ``torch.compile`` can treat them as constants.
    """

    # Standard causal self-attention in a decoder layer.
    DECODER = "decoder"

    # Bidirectional self-attention for image tokens inside a decoder
    # (e.g. VLM visual encoder embedded in the language model).
    DECODER_BIDIRECTIONAL = "decoder_bidirectional"

    # Full bidirectional self-attention in an encoder-only model.
    ENCODER_ONLY = "encoder_only"


# ---------------------------------------------------------------------------
# RadixAttention
# ---------------------------------------------------------------------------


class RadixAttention(nn.Module):
    """Attention layer that delegates computation to a pluggable backend.

    Each transformer attention layer in a pymllm model creates exactly one
    ``RadixAttention`` with a unique ``layer_id``. During the forward pass
    the layer looks up the correct KV buffer via ``layer_id`` and calls the
    backend attached to the current :class:`~pymllm.engine.forward_batch.ForwardBatch`.

    Parameters
    ----------
    num_heads
        Number of query attention heads (after any tensor-parallelism
        sharding; pass the full count if not using TP).
    head_dim
        Per-head dimension for query and key projections.
    scaling
        Softmax pre-scale, typically ``1 / sqrt(head_dim)``.
    num_kv_heads
        Number of key / value heads (supports GQA / MQA).
    layer_id
        Zero-based index of this layer within the model. Used to index into
        ``KVPool.k_buffer`` / ``v_buffer``.
    logit_cap
        If > 0, attention logits are soft-capped to this value via a ``tanh``
        gate (used by Gemma2 / Gemma3 style models). Set to ``0.0`` to
        disable.
    v_head_dim
        Per-head dimension of the value projection. Defaults to ``head_dim``
        (i.e. standard square QKV).
    sliding_window_size
        Sliding-window attention span. ``-1`` means full context (no window).
    is_cross_attention
        ``True`` for cross-attention layers in encoder-decoder models.
    attn_type
        One of :class:`AttentionType`.
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scaling: float,
        num_kv_heads: int,
        layer_id: int,
        logit_cap: float = 0.0,
        v_head_dim: int = -1,
        sliding_window_size: int = -1,
        is_cross_attention: bool = False,
        attn_type: AttentionType = AttentionType.DECODER,
    ):
        super().__init__()

        self.tp_q_head_num: int = num_heads
        self.tp_k_head_num: int = num_kv_heads
        self.tp_v_head_num: int = num_kv_heads

        self.head_dim: int = head_dim
        self.qk_head_dim: int = head_dim
        self.v_head_dim: int = v_head_dim if v_head_dim != -1 else head_dim

        self.scaling: float = scaling
        self.layer_id: int = layer_id
        self.logit_cap: float = logit_cap
        # Defensive: callers may pass None even though the annotation is int.
        self.sliding_window_size: int = (
            sliding_window_size if sliding_window_size is not None else -1
        )
        self.is_cross_attention: bool = is_cross_attention
        self.attn_type: AttentionType = attn_type

    # ------------------------------------------------------------------
    # forward
    # ------------------------------------------------------------------

    def forward(
        self,
        q: torch.Tensor,
        k: Optional[torch.Tensor],
        v: Optional[torch.Tensor],
        forward_batch: "ForwardBatch",
        save_kv_cache: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Run attention for one batch.

        Parameters
        ----------
        q
            Query tensor, shape ``[num_tokens, tp_q_head_num * head_dim]``
            (or already reshaped to ``[num_tokens, tp_q_head_num, head_dim]``).
        k
            Key tensor, same leading dimension as ``q``, shape
            ``[num_tokens, tp_k_head_num * qk_head_dim]``.
            Pass ``None`` for cross-layer KV sharing (``v`` must also be
            ``None`` in this case).
        v
            Value tensor, shape
            ``[num_tokens, tp_v_head_num * v_head_dim]``.
        forward_batch
            Batch metadata and references to memory pools / backend.
        save_kv_cache
            When ``False``, skip writing K/V into the pool (useful for draft
            models in speculative decoding).
        **kwargs
            Passed through to the backend (e.g. ``q_rope``, ``k_rope``).
        """
        if k is not None:
            assert v is not None, "k and v must both be provided or both be None"
            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

        return forward_batch.attn_backend.forward(
            q, k, v, self, forward_batch, save_kv_cache, **kwargs
        )

    def extra_repr(self) -> str:
        return (
            f"layer_id={self.layer_id}, "
            f"q_heads={self.tp_q_head_num}, "
            f"kv_heads={self.tp_k_head_num}, "
            f"head_dim={self.head_dim}, "
            f"v_head_dim={self.v_head_dim}, "
            f"scaling={self.scaling:.4f}, "
            f"logit_cap={self.logit_cap}, "
            f"sliding_window={self.sliding_window_size}, "
            f"attn_type={self.attn_type.value}"
        )
diff --git a/pymllm/layers/attention/radix_linear_attention.py b/pymllm/layers/attention/radix_linear_attention.py
new file mode 100644
index 000000000..01993163d
--- /dev/null
+++ b/pymllm/layers/attention/radix_linear_attention.py
@@ -0,0 +1,116 @@
"""RadixLinearAttention -- GDN linear-attention layer for hybrid models.

Analogous to :class:`RadixAttention` but for GDN (Gated Delta Net) layers.
Stores per-layer GDN parameters and delegates computation to the
:meth:`AttentionBackend.forward_gdn` method on the current
:class:`~pymllm.engine.forward_batch.ForwardBatch`.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch import nn

if TYPE_CHECKING:
    from pymllm.engine.forward_batch import ForwardBatch


class RadixLinearAttention(nn.Module):
    """GDN linear-attention layer that delegates to the attention backend.

    Each GDN layer in a pymllm model creates one ``RadixLinearAttention``
    with a unique ``layer_id`` and ``gdn_layer_idx``. During forward, it
    calls ``forward_batch.attn_backend.forward_gdn(...)`` which routes to
    the appropriate GDN backend implementation.

    Parameters
    ----------
    layer_id : int
        Global zero-based layer index within the model.
    gdn_layer_idx : int
        Sequential zero-based index among GDN layers only (not global).
        Used to index into :class:`~pymllm.mem_cache.memory_pool.GDNPool`.
    num_k_heads : int
        Number of key heads.
    num_v_heads : int
        Number of value heads.
    head_k_dim : int
        Per-head key dimension.
    head_v_dim : int
        Per-head value dimension.
    conv_weight : nn.Parameter
        Reference to the GDNConv1d weight parameter.
    A_log : nn.Parameter
        Log-space decay parameter.
    dt_bias : nn.Parameter
        Bias for the decay gate.
    """

    def __init__(
        self,
        layer_id: int,
        gdn_layer_idx: int,
        num_k_heads: int,
        num_v_heads: int,
        head_k_dim: int,
        head_v_dim: int,
        conv_weight: nn.Parameter,
        A_log: nn.Parameter,
        dt_bias: nn.Parameter,
    ):
        super().__init__()
        self.layer_id = layer_id
        self.gdn_layer_idx = gdn_layer_idx
        self.num_k_heads = num_k_heads
        self.num_v_heads = num_v_heads
        self.head_k_dim = head_k_dim
        self.head_v_dim = head_v_dim
        # Store references to model parameters (not copies)
        self.conv_weight = conv_weight
        self.A_log = A_log
        self.dt_bias = dt_bias

    def forward(
        self,
        forward_batch: "ForwardBatch",
        mixed_qkv: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate GDN computation to the attention backend.

        Parameters
        ----------
        forward_batch
            Batch metadata with ``attn_backend`` attached.
        mixed_qkv
            Concatenated Q/K/V projection output before conv1d.
        a
            Decay gate input, shape ``[num_tokens, num_v_heads]``.
        b
            Update gate input, shape ``[num_tokens, num_v_heads]``.

        Returns
        -------
        torch.Tensor
            GDN attention output, shape ``[num_tokens, num_v_heads * head_v_dim]``.
        """
        return forward_batch.attn_backend.forward_gdn(
            layer=self,
            forward_batch=forward_batch,
            mixed_qkv=mixed_qkv,
            a=a,
            b=b,
        )

    def extra_repr(self) -> str:
        return (
            f"layer_id={self.layer_id}, "
            f"gdn_layer_idx={self.gdn_layer_idx}, "
            f"k_heads={self.num_k_heads}, "
            f"v_heads={self.num_v_heads}, "
            f"k_dim={self.head_k_dim}, "
            f"v_dim={self.head_v_dim}"
        )
diff --git a/pymllm/layers/base.py b/pymllm/layers/base.py
new file mode 100644
index 000000000..3044e2064
--- /dev/null
+++ b/pymllm/layers/base.py
@@ -0,0 +1,28 @@
import torch
from torch import nn
from torch.nn import Parameter
from pymllm.layers.utils import set_weight_attrs
from pymllm.quantization.quant_recipe import QuantRecipe
from typing import Optional


class MllmBaseLayer(nn.Module):
    """Common base for pymllm layers: holds an optional quantization recipe
    and a default checkpoint-weight loader."""

    def __init__(self):
        super().__init__()
        self.quant_recipe: Optional[QuantRecipe] = None

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load weights into a parameter.

        This is the default implementation that directly copies the loaded weight
        into the parameter. Subclasses should override this method to implement
        custom loading logic (e.g., tensor parallelism sharding).

        Args:
            param: The parameter to load weights into.
            loaded_weight: The weight tensor loaded from checkpoint.
+ """ + param.data.copy_(loaded_weight) + + def forward(self, *args, **kwargs): + raise NotImplementedError("Subclasses must implement forward method") diff --git a/pymllm/backends/qualcomm/transformers/core/__init__.py b/pymllm/layers/custom_event.py similarity index 100% rename from pymllm/backends/qualcomm/transformers/core/__init__.py rename to pymllm/layers/custom_event.py diff --git a/pymllm/layers/embedding.py b/pymllm/layers/embedding.py new file mode 100644 index 000000000..ec99c5b2d --- /dev/null +++ b/pymllm/layers/embedding.py @@ -0,0 +1,160 @@ +import torch +import torch.nn.functional as F +from torch.nn import Parameter + +from pymllm.layers.base import MllmBaseLayer +from pymllm.layers.utils import set_weight_attrs +from pymllm.orchestrator import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) + + +class VocabParallelEmbedding(MllmBaseLayer): + """Embedding layer with vocabulary parallelism. + + This layer shards the embedding table along the vocabulary dimension + for tensor parallelism. + + Args: + num_embeddings: Size of the vocabulary. + embedding_dim: Size of the embedding vector. + padding_idx: Index for padding token (optional). 
+ """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + ): + super().__init__() + + # Get TP info from global state + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + + # Calculate sharded size + if self.num_embeddings % self.tp_size != 0: + raise ValueError( + f"num_embeddings ({num_embeddings}) must be divisible by " + f"tp_size ({self.tp_size})" + ) + + self.num_embeddings_per_partition = divide(num_embeddings, self.tp_size) + + # Create sharded weight + self.weight = Parameter( + torch.empty(self.num_embeddings_per_partition, embedding_dim) + ) + + # Calculate shard range + self.vocab_start_index = self.tp_rank * self.num_embeddings_per_partition + self.vocab_end_index = ( + self.vocab_start_index + self.num_embeddings_per_partition + ) + + # Set weight attributes for loading + set_weight_attrs( + self.weight, + { + "output_dim": 0, # Shard along vocab dimension + "input_dim": 1, # Embedding dimension + "weight_loader": self.weight_loader, + }, + ) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + """Load sharded weights into the parameter. + + Args: + param: The parameter to load weights into. + loaded_weight: The weight tensor loaded from checkpoint (full size). 
+ """ + output_dim = getattr(param, "output_dim", None) + + if output_dim is None or self.tp_size == 1: + # No sharding, direct copy + assert param.data.shape == loaded_weight.shape, ( + f"Shape mismatch: param {param.data.shape} vs " + f"loaded {loaded_weight.shape}" + ) + param.data.copy_(loaded_weight) + else: + # Sharded loading: slice the loaded weight + assert loaded_weight.shape[output_dim] == self.num_embeddings, ( + f"Loaded weight vocab size {loaded_weight.shape[output_dim]} " + f"does not match expected {self.num_embeddings}" + ) + + # Slice along vocab dimension + if output_dim == 0: + shard_weight = loaded_weight[ + self.vocab_start_index : self.vocab_end_index, : + ] + else: + shard_weight = loaded_weight.narrow( + output_dim, + self.vocab_start_index, + self.num_embeddings_per_partition, + ) + + assert param.data.shape == shard_weight.shape, ( + f"Shard shape mismatch: param {param.data.shape} vs " + f"shard {shard_weight.shape}" + ) + param.data.copy_(shard_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the embedding layer with TP support. + + Args: + x: Input tensor of token ids. + + Returns: + Embedded representation (all-reduced across TP group if needed). + """ + local_padding_idx = self.padding_idx + if self.tp_size > 1: + # Create mask for valid vocab range + vocab_mask = (x >= self.vocab_start_index) & (x < self.vocab_end_index) + + # Adjust indices to local vocab space + masked_input = torch.where( + vocab_mask, + x - self.vocab_start_index, + torch.zeros_like(x), # Invalid indices become 0 (will be masked) + ) + # F.embedding expects indices in local weight-table space. + # Only pass padding_idx on the owning rank, remapped to local offset. 
+ if self.padding_idx is not None: + if self.vocab_start_index <= self.padding_idx < self.vocab_end_index: + local_padding_idx = self.padding_idx - self.vocab_start_index + else: + local_padding_idx = None + else: + masked_input = x + vocab_mask = None + + # Lookup embeddings + output = F.embedding( + masked_input.long(), + self.weight, + padding_idx=local_padding_idx, + ) + + # Mask invalid positions (for TP) + if vocab_mask is not None: + output.masked_fill_(~vocab_mask.unsqueeze(-1), 0) + + # All-reduce across TP group + if self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output) + + return output diff --git a/pymllm/layers/gated_delta_net.py b/pymllm/layers/gated_delta_net.py new file mode 100644 index 000000000..3753734d9 --- /dev/null +++ b/pymllm/layers/gated_delta_net.py @@ -0,0 +1,168 @@ +"""Gated Delta Network (GDN) linear attention for Qwen3.5. + +This implements the linear attention mechanism used in Qwen3.5's hybrid +architecture. GDN alternates with standard full-attention layers. + +Core formulation (decode, per-head): + g_t = -exp(A_log) * softplus(a_t + dt_bias) + beta_t = sigmoid(b_t) + state_t = exp(g_t) * state_{t-1} + beta_t * (k_t outer v_t) + output_t = (q_t @ state_t) + +State is externalized into a :class:`~pymllm.mem_cache.memory_pool.GDNPool` +and computation is delegated to the attention backend via +:class:`~pymllm.layers.attention.radix_linear_attention.RadixLinearAttention`. 
+""" + +from __future__ import annotations + +import logging +from typing import Any, Optional + +import torch +import torch.nn as nn + +from pymllm.layers.base import MllmBaseLayer +from pymllm.layers.linear import Linear +from pymllm.layers.utils import set_weight_attrs + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Conv1d weight holder +# --------------------------------------------------------------------------- + + +class GDNConv1d(nn.Module): + """Causal 1D convolution weight holder for GDN sequence mixing. + + The actual convolution computation is performed by the GDN backend + using pooled conv states. This module only holds the learnable weight. + """ + + def __init__(self, channels: int, kernel_size: int): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.weight = nn.Parameter(torch.empty(channels, kernel_size)) + + +# --------------------------------------------------------------------------- +# GatedDeltaNet — main GDN layer +# --------------------------------------------------------------------------- + + +class GatedDeltaNet(MllmBaseLayer): + """Gated Delta Network linear attention layer for Qwen3.5. + + State is externalized into a GDNPool and computation is delegated to + the attention backend via RadixLinearAttention. + + Parameters + ---------- + hidden_size : int + Model hidden dimension. + num_k_heads : int + Number of key heads. + num_v_heads : int + Number of value heads. + head_k_dim : int + Per-head key dimension. + head_v_dim : int + Per-head value dimension. + conv_kernel_size : int + Causal conv1d kernel width. + layer_id : int + Global layer index. + gdn_layer_idx : int + Sequential index among GDN layers (0-based). + rms_norm_eps : float + Epsilon for gated RMS normalization. 
+ """ + + def __init__( + self, + hidden_size: int, + num_k_heads: int = 16, + num_v_heads: int = 32, + head_k_dim: int = 128, + head_v_dim: int = 128, + conv_kernel_size: int = 4, + layer_id: int = 0, + gdn_layer_idx: int = 0, + rms_norm_eps: float = 1e-6, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_k_heads = num_k_heads + self.num_v_heads = num_v_heads + self.head_k_dim = head_k_dim + self.head_v_dim = head_v_dim + self.key_dim = head_k_dim * num_k_heads + self.value_dim = head_v_dim * num_v_heads + self.conv_kernel_size = conv_kernel_size + self.layer_id = layer_id + self.gdn_layer_idx = gdn_layer_idx + + # Input projections + self.in_proj_qkv = Linear(hidden_size, self.key_dim * 2 + self.value_dim, bias=False) + self.in_proj_z = Linear(hidden_size, self.value_dim, bias=False) + self.in_proj_a = Linear(hidden_size, num_v_heads, bias=False) + self.in_proj_b = Linear(hidden_size, num_v_heads, bias=False) + + # Causal convolution (weight only — computation is in the backend) + self.conv1d = GDNConv1d(self.key_dim * 2 + self.value_dim, conv_kernel_size) + + # State parameters (must stay float32 for numerical stability) + self.A_log = nn.Parameter(torch.empty(num_v_heads, dtype=torch.float32)) + self.dt_bias = nn.Parameter(torch.ones(num_v_heads, dtype=torch.float32)) + set_weight_attrs(self.A_log, {"weight_loader": self.weight_loader}) + set_weight_attrs(self.dt_bias, {"weight_loader": self.weight_loader}) + + # Gated RMSNorm (mllm-kernel accelerated) + from pymllm.layers.rms_norm_gated import RMSNormGated + self.norm = RMSNormGated(head_v_dim, eps=rms_norm_eps, norm_before_gate=True) + + # Output projection + self.out_proj = Linear(self.value_dim, hidden_size, bias=False) + + # RadixLinearAttention — delegates to the attention backend + from pymllm.layers.attention.radix_linear_attention import RadixLinearAttention + self.attn = RadixLinearAttention( + layer_id=layer_id, + gdn_layer_idx=gdn_layer_idx, + num_k_heads=num_k_heads, + 
num_v_heads=num_v_heads, + head_k_dim=head_k_dim, + head_v_dim=head_v_dim, + conv_weight=self.conv1d.weight, + A_log=self.A_log, + dt_bias=self.dt_bias, + ) + + def forward( + self, hidden_states: torch.Tensor, forward_batch: Any = None, + ) -> torch.Tensor: + seq_len, _ = hidden_states.shape + + # Input projections + mixed_qkv = self.in_proj_qkv(hidden_states) + z = self.in_proj_z(hidden_states) + a = self.in_proj_a(hidden_states) + b = self.in_proj_b(hidden_states) + + # Delegate to backend via RadixLinearAttention + # The backend handles: conv1d, SiLU, split, gating, recurrent update + attn_out = self.attn(forward_batch, mixed_qkv, a, b) + + # Gated norm + output projection + attn_out = attn_out.view(seq_len, self.num_v_heads, self.head_v_dim) + z = z.view(seq_len, self.num_v_heads, self.head_v_dim) + + attn_flat = attn_out.reshape(-1, self.head_v_dim) + z_flat = z.reshape(-1, self.head_v_dim) + normed = self.norm(attn_flat, z_flat) + normed = normed.view(seq_len, self.num_v_heads, self.head_v_dim) + normed = normed.reshape(seq_len, self.value_dim) + return self.out_proj(normed) diff --git a/pymllm/layers/layer_norm.py b/pymllm/layers/layer_norm.py new file mode 100644 index 000000000..54d94c19e --- /dev/null +++ b/pymllm/layers/layer_norm.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import torch +import flashinfer +from torch.nn import Parameter + +from pymllm.layers.base import MllmBaseLayer +from pymllm.layers.utils import set_weight_attrs + + +class LayerNorm(MllmBaseLayer): + """LayerNorm layer implemented with FlashInfer kernel.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + # flashinfer.norm.layernorm expects gamma/beta in fp32. 
+ self.weight = Parameter(torch.ones(hidden_size, dtype=torch.float32)) + self.bias = Parameter(torch.zeros(hidden_size, dtype=torch.float32)) + set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) + set_weight_attrs(self.bias, {"weight_loader": self.weight_loader}) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.shape[-1] != self.hidden_size: + raise ValueError( + f"Expected last dim == hidden_size ({self.hidden_size}), " + f"but got input shape {tuple(x.shape)}" + ) + if x.dtype != torch.bfloat16: + raise TypeError( + "flashinfer.norm.layernorm requires bfloat16 input, " + f"but got {x.dtype}" + ) + + if x.dim() == 2: + return flashinfer.norm.layernorm(x, self.weight, self.bias, self.eps) + + original_shape = x.shape + x_2d = x.reshape(-1, self.hidden_size) + out = flashinfer.norm.layernorm(x_2d, self.weight, self.bias, self.eps) + return out.reshape(original_shape) diff --git a/pymllm/layers/linear.py b/pymllm/layers/linear.py new file mode 100644 index 000000000..dc583e931 --- /dev/null +++ b/pymllm/layers/linear.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch.nn import Parameter + +from pymllm.layers.base import MllmBaseLayer +from pymllm.layers.utils import set_weight_attrs +from pymllm.orchestrator import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) + + +class ColumnParallelLinear(MllmBaseLayer): + """Linear layer with column parallelism (output-dimension sharding). + + The weight matrix is split along the output dimension across TP ranks. + Each rank holds ``out_features / tp_size`` rows of the weight. + + Args: + in_features: Size of each input sample. + out_features: Size of each output sample (before sharding). + bias: If ``True``, adds a learnable bias. 
+ gather_output: If ``True``, all-gather the output across TP ranks + so every rank gets the full ``out_features``. Set to ``False`` + when the next layer is a :class:`RowParallelLinear` that expects + a split input. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + gather_output: bool = True, + ): + super().__init__() + + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + + if out_features % self.tp_size != 0: + raise ValueError( + f"out_features ({out_features}) must be divisible by " + f"tp_size ({self.tp_size})" + ) + self.out_features_per_partition = divide(out_features, self.tp_size) + + self.output_start_index = self.tp_rank * self.out_features_per_partition + self.output_end_index = self.output_start_index + self.out_features_per_partition + + self.weight = Parameter( + torch.empty(self.out_features_per_partition, in_features) + ) + set_weight_attrs( + self.weight, + { + "output_dim": 0, + "input_dim": 1, + "weight_loader": self.weight_loader, + }, + ) + + if bias: + self.bias_flag = True + self.bias = Parameter(torch.empty(self.out_features_per_partition)) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.bias_flag = False + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + """Load sharded weights into the parameter. + + Args: + param: The parameter to load weights into. + loaded_weight: The weight tensor loaded from checkpoint (full size). 
+ """ + output_dim = getattr(param, "output_dim", None) + + if output_dim is None or self.tp_size == 1: + assert param.data.shape == loaded_weight.shape, ( + f"Shape mismatch: param {param.data.shape} vs " + f"loaded {loaded_weight.shape}" + ) + param.data.copy_(loaded_weight) + else: + shard_weight = loaded_weight.narrow( + output_dim, + self.output_start_index, + self.out_features_per_partition, + ) + assert param.data.shape == shard_weight.shape, ( + f"Shard shape mismatch: param {param.data.shape} vs " + f"shard {shard_weight.shape}" + ) + param.data.copy_(shard_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output = F.linear(x, self.weight, self.bias) + + if self.gather_output and self.tp_size > 1: + output = tensor_model_parallel_all_gather(output, dim=-1) + + return output + + +class RowParallelLinear(MllmBaseLayer): + """Linear layer with row parallelism (input-dimension sharding). + + The weight matrix is split along the input dimension across TP ranks. + Each rank holds all ``out_features`` rows but only + ``in_features / tp_size`` columns. + + Typically placed after a :class:`ColumnParallelLinear` whose + ``gather_output=False``, so the input is already split. + + Args: + in_features: Size of each input sample (before sharding). + out_features: Size of each output sample. + bias: If ``True``, adds a learnable bias (applied after all-reduce). + reduce_output: If ``True``, all-reduce the output across TP ranks. 
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + reduce_output: bool = True, + ): + super().__init__() + + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + + self.in_features = in_features + self.out_features = out_features + self.reduce_output = reduce_output + + if in_features % self.tp_size != 0: + raise ValueError( + f"in_features ({in_features}) must be divisible by " + f"tp_size ({self.tp_size})" + ) + self.in_features_per_partition = divide(in_features, self.tp_size) + + self.input_start_index = self.tp_rank * self.in_features_per_partition + self.input_end_index = self.input_start_index + self.in_features_per_partition + + self.weight = Parameter( + torch.empty(out_features, self.in_features_per_partition) + ) + set_weight_attrs( + self.weight, + { + "output_dim": 0, + "input_dim": 1, + "weight_loader": self.weight_loader, + }, + ) + + if bias: + self.bias_flag = True + self.bias = Parameter(torch.empty(out_features)) + set_weight_attrs(self.bias, {"weight_loader": self.weight_loader}) + else: + self.bias_flag = False + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + """Load sharded weights into the parameter. + + Args: + param: The parameter to load weights into. + loaded_weight: The weight tensor loaded from checkpoint (full size). 
+ """ + input_dim = getattr(param, "input_dim", None) + + if input_dim is None or self.tp_size == 1: + assert param.data.shape == loaded_weight.shape, ( + f"Shape mismatch: param {param.data.shape} vs " + f"loaded {loaded_weight.shape}" + ) + param.data.copy_(loaded_weight) + else: + shard_weight = loaded_weight.narrow( + input_dim, + self.input_start_index, + self.in_features_per_partition, + ) + assert param.data.shape == shard_weight.shape, ( + f"Shard shape mismatch: param {param.data.shape} vs " + f"shard {shard_weight.shape}" + ) + param.data.copy_(shard_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output = F.linear(x, self.weight) + + if self.reduce_output and self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output) + + if self.bias is not None: + output = output + self.bias + + return output + + +class Linear(MllmBaseLayer): + """Linear layer with simple quant dispatch.""" + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.weight = Parameter(torch.empty(out_features, in_features)) + set_weight_attrs( + self.weight, + { + "output_dim": 0, + "input_dim": 1, + "weight_loader": self.weight_loader, + }, + ) + + if bias: + self.bias = Parameter(torch.empty(out_features)) + set_weight_attrs(self.bias, {"weight_loader": self.weight_loader}) + else: + self.register_parameter("bias", None) + + def _forward_torch_linear(self, x: torch.Tensor) -> torch.Tensor: + return F.linear(x, self.weight, self.bias) + + def _forward_quant_linear(self, x: torch.Tensor) -> torch.Tensor: + # TODO(wch): Implement quantized linear path. 
+ raise NotImplementedError("quant_linear is not implemented yet.") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.quant_recipe is None: + return self._forward_torch_linear(x) + return self._forward_quant_linear(x) diff --git a/pymllm/layers/mlp.py b/pymllm/layers/mlp.py new file mode 100644 index 000000000..1a40db92e --- /dev/null +++ b/pymllm/layers/mlp.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +import logging +from typing import Callable, Literal, Optional + +import flashinfer +import torch + +from pymllm.layers.base import MllmBaseLayer +from pymllm.layers.linear import ColumnParallelLinear, Linear, RowParallelLinear + +logger = logging.getLogger(__name__) + +MLPActivation = Literal["silu", "gelu", "gelu_tanh"] + +_ACTIVATION_MAP: dict[MLPActivation, Callable[..., torch.Tensor]] = { + "silu": flashinfer.activation.silu_and_mul, + "gelu": flashinfer.activation.gelu_and_mul, + "gelu_tanh": flashinfer.activation.gelu_tanh_and_mul, +} + + +def _validate_mlp_args( + hidden_size: int, intermediate_size: int, activation: str +) -> None: + if hidden_size <= 0: + raise ValueError(f"hidden_size must be > 0, but got {hidden_size}") + if intermediate_size <= 0: + raise ValueError( + f"intermediate_size must be > 0, but got {intermediate_size}" + ) + if activation not in _ACTIVATION_MAP: + raise ValueError( + f"Unsupported activation '{activation}'. " + f"Expected one of: {list(_ACTIVATION_MAP)}" + ) + + +def _run_gated_activation( + gate_up: torch.Tensor, + intermediate_size: int, + activation: MLPActivation, + enable_pdl: Optional[bool], +) -> torch.Tensor: + if gate_up.shape[-1] != 2 * intermediate_size: + raise ValueError( + "Expected last dim of gate_up tensor to be " + f"{2 * intermediate_size}, but got {gate_up.shape[-1]}" + ) + return _ACTIVATION_MAP[activation](gate_up, enable_pdl=enable_pdl) + + +class MLP(MllmBaseLayer): + """Feed-forward MLP block with FlashInfer fused gated activations. + + Non-parallel version (TP=1). 
Uses :class:`Linear` for all projections. + + Supported activations: ``silu``, ``gelu``, ``gelu_tanh``. + """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + activation: MLPActivation = "silu", + use_fused_gate_up_proj: bool = True, + use_bias_gate_up: bool = False, + use_bias_down: bool = False, + enable_pdl: Optional[bool] = None, + ): + super().__init__() + _validate_mlp_args(hidden_size, intermediate_size, activation) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.activation = activation + self.use_fused_gate_up_proj = use_fused_gate_up_proj + self.enable_pdl = enable_pdl + + if not use_fused_gate_up_proj: + logger.warning( + "MLP with use_fused_gate_up_proj=False uses a lower-efficiency path. " + "Use use_fused_gate_up_proj=True for better performance.", + ) + + if use_fused_gate_up_proj: + self.gate_up_proj = Linear( + hidden_size, 2 * intermediate_size, bias=use_bias_gate_up, + ) + self.gate_proj = None + self.up_proj = None + else: + self.gate_up_proj = None + self.gate_proj = Linear( + hidden_size, intermediate_size, bias=use_bias_gate_up, + ) + self.up_proj = Linear( + hidden_size, intermediate_size, bias=use_bias_gate_up, + ) + + self.down_proj = Linear( + intermediate_size, hidden_size, bias=use_bias_down, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.shape[-1] != self.hidden_size: + raise ValueError( + f"Expected last dim == hidden_size ({self.hidden_size}), " + f"but got input shape {tuple(x.shape)}" + ) + + if self.use_fused_gate_up_proj: + assert self.gate_up_proj is not None + gate_up = self.gate_up_proj(x) + else: + assert self.gate_proj is not None and self.up_proj is not None + gate_up = torch.cat([self.gate_proj(x), self.up_proj(x)], dim=-1) + + hidden = _run_gated_activation( + gate_up, self.intermediate_size, self.activation, self.enable_pdl, + ) + return self.down_proj(hidden) + + +class ParallelMLP(MllmBaseLayer): + """Tensor-parallel MLP with 
column-sharded intermediate dimension. + + Projection layout (Megatron-style): + + - ``gate_proj``: :class:`ColumnParallelLinear` + ``(hidden_size → intermediate_size, gather_output=False)`` + - ``up_proj``: :class:`ColumnParallelLinear` + ``(hidden_size → intermediate_size, gather_output=False)`` + - ``down_proj``: :class:`RowParallelLinear` + ``(intermediate_size → hidden_size, reduce_output=True)`` + + Gate and up projections are kept separate so that each TP rank holds a + correctly paired ``[gate_shard, up_shard]`` for the gated activation. + + Cost: **1 all-reduce** (inside ``down_proj``). + + Input shape : ``(*, hidden_size)`` — full / replicated. + Output shape: ``(*, hidden_size)`` — full / replicated. + + Args: + hidden_size: Model hidden dimension. + intermediate_size: Intermediate (expanded) dimension **before** TP + sharding. + activation: Gated activation type. + use_bias_gate_up: Add bias to the gate/up projections. + use_bias_down: Add bias to the down projection. + enable_pdl: FlashInfer PDL flag. 
+ """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + activation: MLPActivation = "silu", + use_bias_gate_up: bool = False, + use_bias_down: bool = False, + enable_pdl: Optional[bool] = None, + ): + super().__init__() + _validate_mlp_args(hidden_size, intermediate_size, activation) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.activation = activation + self.enable_pdl = enable_pdl + + self.gate_proj = ColumnParallelLinear( + hidden_size, intermediate_size, + bias=use_bias_gate_up, gather_output=False, + ) + self.up_proj = ColumnParallelLinear( + hidden_size, intermediate_size, + bias=use_bias_gate_up, gather_output=False, + ) + + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, + bias=use_bias_down, reduce_output=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.shape[-1] != self.hidden_size: + raise ValueError( + f"Expected last dim == hidden_size ({self.hidden_size}), " + f"but got input shape {tuple(x.shape)}" + ) + + gate_up = torch.cat([self.gate_proj(x), self.up_proj(x)], dim=-1) + + shard_inter = self.down_proj.in_features_per_partition + hidden = _run_gated_activation( + gate_up, shard_inter, self.activation, self.enable_pdl, + ) + return self.down_proj(hidden) diff --git a/pymllm/layers/rms_norm.py b/pymllm/layers/rms_norm.py new file mode 100644 index 000000000..b20b36f30 --- /dev/null +++ b/pymllm/layers/rms_norm.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import Optional, Tuple, Union + +import torch +import flashinfer +from torch.nn import Parameter + +from pymllm.layers.base import MllmBaseLayer +from pymllm.layers.utils import set_weight_attrs + + +class RMSNorm(MllmBaseLayer): + """RMSNorm layer implemented with FlashInfer kernel.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + self.weight = 
Parameter(torch.empty(hidden_size)) + set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + flashinfer.norm.fused_add_rmsnorm(x, residual, self.weight.data, self.eps) + return x, residual + + if x.shape[-1] != self.hidden_size: + raise ValueError( + f"Expected last dim == hidden_size ({self.hidden_size}), " + f"but got input shape {tuple(x.shape)}" + ) + + # FlashInfer rmsnorm accepts 2D/3D input; flatten higher-rank tensors to 2D. + if x.dim() in (2, 3): + return flashinfer.norm.rmsnorm(x, self.weight, self.eps) + + original_shape = x.shape + x_2d = x.reshape(-1, self.hidden_size) + out = flashinfer.norm.rmsnorm(x_2d, self.weight, self.eps) + return out.reshape(original_shape) + + +class GemmaRMSNorm(MllmBaseLayer): + """Gemma-style RMSNorm layer implemented with FlashInfer kernel.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + self.weight = Parameter(torch.empty(hidden_size)) + set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + flashinfer.norm.gemma_fused_add_rmsnorm( + x, residual, self.weight.data, self.eps + ) + return x, residual + + if x.shape[-1] != self.hidden_size: + raise ValueError( + f"Expected last dim == hidden_size ({self.hidden_size}), " + f"but got input shape {tuple(x.shape)}" + ) + + # gemma_rmsnorm is defined on 2D input; flatten other ranks to 2D. 
+ if x.dim() == 2: + return flashinfer.norm.gemma_rmsnorm(x, self.weight, self.eps) + + original_shape = x.shape + x_2d = x.reshape(-1, self.hidden_size) + out = flashinfer.norm.gemma_rmsnorm(x_2d, self.weight, self.eps) + return out.reshape(original_shape) diff --git a/pymllm/layers/rms_norm_gated.py b/pymllm/layers/rms_norm_gated.py new file mode 100644 index 000000000..caec9b88d --- /dev/null +++ b/pymllm/layers/rms_norm_gated.py @@ -0,0 +1,154 @@ +"""Gated RMSNorm layer for Qwen3.5 GDN attention. + +Computes ``rmsnorm(x, weight, eps) * silu(z)`` using a fused CUDA kernel +from mllm-kernel. Falls back to PyTorch when the kernel is unavailable. +""" + +from __future__ import annotations + +import logging +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter + +from pymllm.layers.base import MllmBaseLayer +from pymllm.layers.utils import set_weight_attrs + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Try to load the mllm-kernel fused CUDA implementation +# --------------------------------------------------------------------------- +_HAS_MLLM_KERNEL_CUDA = False +try: + from mllm_kernel.cuda.jit.rms_norm_gated import ( + rms_norm_gated as _mllm_rms_norm_gated, + ) + + _HAS_MLLM_KERNEL_CUDA = True +except Exception: + _mllm_rms_norm_gated = None + + +# --------------------------------------------------------------------------- +# Pure-PyTorch fallback +# --------------------------------------------------------------------------- + + +def _rms_norm_gated_pytorch( + x: torch.Tensor, + weight: torch.Tensor, + z: Optional[torch.Tensor] = None, + eps: float = 1e-6, + norm_before_gate: bool = True, +) -> torch.Tensor: + """Pure-PyTorch reference implementation.""" + dtype = x.dtype + x_fp32 = x.float() + w_fp32 = weight.float() + z_fp32 = z.float() if z is not None else None + + if z_fp32 is not None and not 
norm_before_gate: + x_fp32 = x_fp32 * F.silu(z_fp32) + + variance = x_fp32.pow(2).mean(dim=-1, keepdim=True) + rstd = torch.rsqrt(variance + eps) + out = x_fp32 * rstd * w_fp32 + + if z_fp32 is not None and norm_before_gate: + out = out * F.silu(z_fp32) + + return out.to(dtype) + + +# --------------------------------------------------------------------------- +# Unified dispatch +# --------------------------------------------------------------------------- + + +def rms_norm_gated( + x: torch.Tensor, + weight: torch.Tensor, + z: Optional[torch.Tensor] = None, + eps: float = 1e-6, + norm_before_gate: bool = True, +) -> torch.Tensor: + """Compute (optionally gated) RMS normalization. + + Uses the fused mllm-kernel CUDA implementation when available, + otherwise falls back to a pure-PyTorch implementation. + """ + if _HAS_MLLM_KERNEL_CUDA and x.is_cuda: + return _mllm_rms_norm_gated(x, weight, z=z, eps=eps) + return _rms_norm_gated_pytorch( + x, weight, z=z, eps=eps, norm_before_gate=norm_before_gate, + ) + + +# --------------------------------------------------------------------------- +# nn.Module wrapper +# --------------------------------------------------------------------------- + + +class RMSNormGated(MllmBaseLayer): + """Gated RMS Normalization layer for Qwen3.5 GDN attention. + + Computes:: + + output = rmsnorm(x, weight) * silu(z) # z is not None + output = rmsnorm(x, weight) # z is None + + Uses a fused CUDA kernel from mllm-kernel for maximum throughput. + + Parameters + ---------- + hidden_size : int + Dimensionality of the input (and weight vector). + eps : float + Small constant for numerical stability. + norm_before_gate : bool + If ``True`` (default): ``rmsnorm(x) * silu(z)``. + If ``False``: ``rmsnorm(x * silu(z))``. 
+ """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + group_size: Optional[int] = None, + norm_before_gate: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.norm_before_gate = norm_before_gate + + factory_kwargs = {} + if device is not None: + factory_kwargs["device"] = device + if dtype is not None: + factory_kwargs["dtype"] = dtype + + self.weight = Parameter(torch.ones(hidden_size, **factory_kwargs)) + set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) + + def forward( + self, + x: torch.Tensor, + z: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return rms_norm_gated( + x, self.weight, z=z, eps=self.eps, + norm_before_gate=self.norm_before_gate, + ) + + def extra_repr(self) -> str: + return ( + f"hidden_size={self.hidden_size}, eps={self.eps}, " + f"norm_before_gate={self.norm_before_gate}" + ) diff --git a/pymllm/layers/rope.py b/pymllm/layers/rope.py new file mode 100644 index 000000000..94f89b20d --- /dev/null +++ b/pymllm/layers/rope.py @@ -0,0 +1,401 @@ +from __future__ import annotations + +from typing import List, Optional, Tuple + +import torch +import flashinfer + + +def apply_rope( + q: torch.Tensor, + k: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + inplace: bool = False, + rotary_dim: Optional[int] = None, + interleave: bool = False, + rope_scale: float = 1.0, + rope_theta: float = 1e4, +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor). + + cos/sin values are computed on the fly inside the kernel. Position offsets + are provided per-segment via ``indptr`` and ``offsets``. + + Args: + q: Query ragged tensor, shape ``(nnz, num_q_heads, head_dim)``. + k: Key ragged tensor, shape ``(nnz, num_k_heads, head_dim)``. + indptr: Indptr tensor, shape ``(batch_size + 1,)``. 
The i-th segment + spans ``q[indptr[i]:indptr[i+1]]``. + offsets: Relative position offsets per segment, shape ``(batch_size,)``. + inplace: If ``True``, apply RoPE in-place and return ``None``. + If ``False``, return new ``(q_rope, k_rope)`` tensors. + rotary_dim: Number of dimensions to apply RoPE to. ``None`` means + the entire ``head_dim``. + interleave: If ``True``, rotate even/odd dims (``[..., ::2]`` / + ``[..., 1::2]``). If ``False``, rotate first/second half dims. + rope_scale: Scaling factor for position indices. + rope_theta: Base frequency theta. + + Returns: + ``None`` when *inplace* is ``True``, otherwise a tuple + ``(q_rope, k_rope)`` of rotated tensors with the same shapes as + the inputs. + """ + if inplace: + flashinfer.rope.apply_rope_inplace( + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + ) + return None + + return flashinfer.rope.apply_rope( + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + ) + + +def apply_llama31_rope( + q: torch.Tensor, + k: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + inplace: bool = False, + rotary_dim: Optional[int] = None, + interleave: bool = False, + rope_scale: float = 8.0, + rope_theta: float = 5e5, + low_freq_factor: float = 1.0, + high_freq_factor: float = 4.0, + old_context_len: int = 8192, +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """Apply Llama 3.1 style rotary embedding to a batch of queries/keys. + + This variant adjusts frequencies with ``low_freq_factor``, + ``high_freq_factor``, and ``old_context_len`` following the Llama 3.1 + RoPE recipe. cos/sin values are computed on the fly. + + Args: + q: Query ragged tensor, shape ``(nnz, num_q_heads, head_dim)``. + k: Key ragged tensor, shape ``(nnz, num_k_heads, head_dim)``. + indptr: Indptr tensor, shape ``(batch_size + 1,)``. 
+ offsets: Relative position offsets per segment, shape ``(batch_size,)``. + inplace: If ``True``, apply in-place and return ``None``. + rotary_dim: Number of dimensions to apply RoPE to. ``None`` means + the entire ``head_dim``. + interleave: If ``True``, rotate even/odd dims; otherwise first/second + half dims. + rope_scale: Scaling factor for position indices (default ``8``). + rope_theta: Base frequency theta (default ``5e5``). + low_freq_factor: Low frequency factor for Llama 3.1 RoPE. + high_freq_factor: High frequency factor for Llama 3.1 RoPE. + old_context_len: Original context length for Llama 3.1 RoPE. + + Returns: + ``None`` when *inplace* is ``True``, otherwise ``(q_rope, k_rope)``. + """ + if inplace: + flashinfer.rope.apply_llama31_rope_inplace( + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + low_freq_factor=low_freq_factor, + high_freq_factor=high_freq_factor, + old_context_len=old_context_len, + ) + return None + + return flashinfer.rope.apply_llama31_rope( + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + low_freq_factor=low_freq_factor, + high_freq_factor=high_freq_factor, + old_context_len=old_context_len, + ) + + +def apply_rope_pos_ids( + q: torch.Tensor, + k: torch.Tensor, + pos_ids: torch.Tensor, + inplace: bool = False, + rotary_dim: Optional[int] = None, + interleave: bool = False, + rope_scale: float = 1.0, + rope_theta: float = 1e4, +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """Apply rotary embedding using explicit per-token position IDs. + + Unlike :func:`apply_rope` which derives positions from ``indptr`` / + ``offsets``, this function takes a flat ``pos_ids`` tensor that supplies + an explicit position for every token. + + Args: + q: Query tensor, shape ``(nnz, num_q_heads, head_dim)``. + k: Key tensor, shape ``(nnz, num_k_heads, head_dim)``. 
+ pos_ids: Position indices, shape ``(nnz,)``. + inplace: If ``True``, apply in-place and return ``None``. + rotary_dim: Number of dimensions to apply RoPE to. + interleave: Interleaved layout flag. + rope_scale: Scaling factor for position indices. + rope_theta: Base frequency theta. + + Returns: + ``None`` when *inplace* is ``True``, otherwise ``(q_rope, k_rope)``. + """ + if inplace: + flashinfer.rope.apply_rope_pos_ids_inplace( + q, + k, + pos_ids, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + ) + return None + + return flashinfer.rope.apply_rope_pos_ids( + q, + k, + pos_ids, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + ) + + +def apply_llama31_rope_pos_ids( + q: torch.Tensor, + k: torch.Tensor, + pos_ids: torch.Tensor, + inplace: bool = False, + rotary_dim: Optional[int] = None, + interleave: bool = False, + rope_scale: float = 8.0, + rope_theta: float = 5e5, + low_freq_factor: float = 1.0, + high_freq_factor: float = 4.0, + old_context_len: int = 8192, +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """Apply Llama 3.1 style RoPE using explicit per-token position IDs. + + Combines Llama 3.1 frequency adjustments with explicit ``pos_ids``. + + Args: + q: Query tensor, shape ``(nnz, num_q_heads, head_dim)``. + k: Key tensor, shape ``(nnz, num_k_heads, head_dim)``. + pos_ids: Position indices, shape ``(nnz,)``. + inplace: If ``True``, apply in-place and return ``None``. + rotary_dim: Number of dimensions to apply RoPE to. + interleave: Interleaved layout flag. + rope_scale: Scaling factor (default ``8``). + rope_theta: Base frequency theta (default ``5e5``). + low_freq_factor: Low frequency factor for Llama 3.1 RoPE. + high_freq_factor: High frequency factor for Llama 3.1 RoPE. + old_context_len: Original context length for Llama 3.1 RoPE. + + Returns: + ``None`` when *inplace* is ``True``, otherwise ``(q_rope, k_rope)``. 
+ """ + if inplace: + flashinfer.rope.apply_llama31_rope_pos_ids_inplace( + q, + k, + pos_ids, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + low_freq_factor=low_freq_factor, + high_freq_factor=high_freq_factor, + old_context_len=old_context_len, + ) + return None + + return flashinfer.rope.apply_llama31_rope_pos_ids( + q, + k, + pos_ids, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + low_freq_factor=low_freq_factor, + high_freq_factor=high_freq_factor, + old_context_len=old_context_len, + ) + + +def apply_rope_with_cos_sin_cache( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + inplace: bool = False, + is_neox: bool = True, +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """Apply rotary embedding with precomputed cos/sin cache. + + Compatible with SGL/vLLM implementations. Note that ``query`` and ``key`` + use a **flattened** head layout ``(nnz, num_heads * head_size)`` instead + of the 3-D layout used by the other ``apply_rope*`` functions. + + Args: + positions: Position indices, shape ``(nnz,)``. + query: Query tensor, shape ``(nnz, num_q_heads * head_size)``. + key: Key tensor, shape ``(nnz, num_k_heads * head_size)``. + head_size: Size of each attention head. + cos_sin_cache: Precomputed cos/sin tensor, shape + ``(max_seq_len, rotary_dim)``. The first half of ``rotary_dim`` + stores cosine values, the second half stores sine values. + inplace: If ``True``, apply in-place and return ``None``. + is_neox: If ``True`` (default), use GPT-NeoX style (rotate + first/second half dims). If ``False``, use interleaved style + (rotate even/odd dims). + + Returns: + ``None`` when *inplace* is ``True``, otherwise + ``(query_out, key_out)`` with the same shapes as the inputs. 
+ """ + if inplace: + flashinfer.rope.apply_rope_with_cos_sin_cache_inplace( + positions, + query, + key, + head_size, + cos_sin_cache, + is_neox=is_neox, + ) + return None + + return flashinfer.rope.apply_rope_with_cos_sin_cache( + positions, + query, + key, + head_size, + cos_sin_cache, + is_neox=is_neox, + ) + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotate the second half of the last dimension into the first half (neox-style).""" + half = x.shape[-1] // 2 + return torch.cat((-x[..., half:], x[..., :half]), dim=-1) + + +def apply_mrope( + q: torch.Tensor, + k: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + mrope_section: List[int], + mrope_interleaved: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply multi-dimensional rotary position embedding (M-RoPE). + + Used by Qwen3-VL which assigns independent (t, h, w) position indices to + each token. For text tokens all three indices are the same sequential + value; for image tokens they follow the spatial grid layout. + + Args: + q: Query tensor, shape ``(T, num_q_heads, head_dim)``. + k: Key tensor, shape ``(T, num_kv_heads, head_dim)``. + positions: 3-D position IDs, shape ``(3, T)`` — rows are + ``(temporal, height, width)`` position indices. + cos_sin_cache: Precomputed cache, shape ``(max_pos, head_dim)``. + The first ``head_dim // 2`` columns are cosine values and the + remaining columns are sine values, each for frequencies + ``0, 1, ..., head_dim // 2 - 1``. + mrope_section: Three integers ``[s_t, s_h, s_w]`` that partition + the ``head_dim // 2`` rotary frequency dimensions among the + temporal, height, and width components. + ``sum(mrope_section)`` must equal ``head_dim // 2``. + mrope_interleaved: When ``True`` (Qwen3-VL default), uses the + interleaved layout where frequency dimensions are cycled + ``(t, h, w, t, h, w, ...)`` rather than grouped consecutively. + + Returns: + ``(q_rope, k_rope)`` with the same shapes as the inputs. 
+ """ + rotary_dim = cos_sin_cache.shape[-1] # = head_dim + half_dim = rotary_dim // 2 + + # Look up cos/sin for each of the 3 position dimensions. + # positions: [3, T] => cos_sin: [3, T, rotary_dim] + cos_sin = cos_sin_cache[positions] + cos = cos_sin[..., :half_dim] # [3, T, half_dim] + sin = cos_sin[..., half_dim:] # [3, T, half_dim] + + if mrope_interleaved: + # Interleaved layout (Qwen3-VL): within the first + # mrope_section[1]*3 frequency dims, indices cycle (t, h, w). + # Remaining dims (indices >= span) all use the temporal position. + # Matches SGLang's apply_interleaved_rope. + cos_merged = cos[0].clone() # start with temporal; shape [T, half_dim] + sin_merged = sin[0].clone() + span_h = mrope_section[1] * 3 + span_w = mrope_section[2] * 3 + cos_merged[..., 1:span_h:3] = cos[1, ..., 1:span_h:3] + cos_merged[..., 2:span_w:3] = cos[2, ..., 2:span_w:3] + sin_merged[..., 1:span_h:3] = sin[1, ..., 1:span_h:3] + sin_merged[..., 2:span_w:3] = sin[2, ..., 2:span_w:3] + else: + # Non-interleaved (Qwen2-VL style): consecutive frequency sections. 
+ cos_sects = cos.split(mrope_section, dim=-1) # list of [T, s_i] + sin_sects = sin.split(mrope_section, dim=-1) + # Section i picks its cos/sin from positions[i] + cos_merged = torch.cat( + [cos_sects[i][i] for i in range(3)], dim=-1 + ) # [T, half_dim] + sin_merged = torch.cat( + [sin_sects[i][i] for i in range(3)], dim=-1 + ) # [T, half_dim] + + # Expand to full rotary_dim for the neox-style rotation formula: + # q_rot = q * cos_full + rotate_half(q) * sin_full + cos_full = cos_merged.repeat(1, 2) # [T, rotary_dim] + sin_full = sin_merged.repeat(1, 2) # [T, rotary_dim] + cos_4d = cos_full.unsqueeze(1) # [T, 1, rotary_dim] -- broadcasts over heads + sin_4d = sin_full.unsqueeze(1) + + q_rot = q[..., :rotary_dim] * cos_4d + _rotate_half(q[..., :rotary_dim]) * sin_4d + k_rot = k[..., :rotary_dim] * cos_4d + _rotate_half(k[..., :rotary_dim]) * sin_4d + + q_out = ( + torch.cat([q_rot, q[..., rotary_dim:]], dim=-1) + if rotary_dim < q.shape[-1] + else q_rot + ) + k_out = ( + torch.cat([k_rot, k[..., rotary_dim:]], dim=-1) + if rotary_dim < k.shape[-1] + else k_rot + ) + return q_out, k_out diff --git a/pymllm/layers/sampling.py b/pymllm/layers/sampling.py new file mode 100644 index 000000000..26c769ffd --- /dev/null +++ b/pymllm/layers/sampling.py @@ -0,0 +1,776 @@ +"""Sampling operations with FlashInfer acceleration and PyTorch fallback. + +This module wraps all flashinfer.sampling APIs and provides pure-PyTorch +fallback implementations so that the rest of the codebase can import from +here without worrying about whether FlashInfer is installed. 
+""" + +from __future__ import annotations + +import logging +from typing import Optional, Tuple, Union + +import torch + +logger = logging.getLogger(__name__) + +try: + import flashinfer.sampling as _fi_sampling + + _HAS_FLASHINFER = True +except ImportError: + _HAS_FLASHINFER = False + logger.warning("flashinfer not found, falling back to PyTorch sampling kernels") + + +# --------------------------------------------------------------------------- +# Helper utilities (torch fallback) +# --------------------------------------------------------------------------- + + +def _resolve_indices( + data: torch.Tensor, indices: Optional[torch.Tensor] +) -> torch.Tensor: + """If *indices* is given, gather rows from *data* accordingly.""" + if indices is None: + return data + return data[indices.long()] + + +def _to_scalar_or_tensor( + value: Union[torch.Tensor, float, int], + batch_size: int, + device: torch.device, +) -> torch.Tensor: + """Broadcast a scalar or per-batch tensor to shape ``(batch_size,)``.""" + if isinstance(value, (int, float)): + return torch.full((batch_size,), value, device=device, dtype=torch.float32) + return value.to(device=device, dtype=torch.float32) + + +# --------------------------------------------------------------------------- +# softmax +# --------------------------------------------------------------------------- + + +def softmax( + logits: torch.Tensor, + temperature: Optional[Union[torch.Tensor, float]] = None, + enable_pdl: Optional[bool] = None, +) -> torch.Tensor: + """Safe softmax with optional temperature scaling. + + Parameters + ---------- + logits : torch.Tensor + Shape ``(batch_size, num_classes)``. + temperature : Optional[Union[torch.Tensor, float]] + Scalar or per-request ``(batch_size,)`` temperature. + enable_pdl : Optional[bool] + FlashInfer PDL flag (ignored in fallback). + + Returns + ------- + torch.Tensor + Probabilities with the same shape as *logits*. 
+ """ + # Clamp temperature to avoid division by zero (temperature=0 → greedy). + # Replace 0 with 1 here; the caller (ModelRunner.sample) handles + # temperature=0 via argmax before reaching this path. + if temperature is not None: + if isinstance(temperature, torch.Tensor): + temperature = temperature.clamp(min=1e-6) + elif temperature < 1e-6: + temperature = 1.0 # effectively no scaling; caller uses argmax + + if _HAS_FLASHINFER: + return _fi_sampling.softmax( + logits, temperature=temperature, enable_pdl=enable_pdl + ) + + if temperature is not None: + if isinstance(temperature, (int, float)): + logits = logits / temperature + else: + logits = logits / temperature.unsqueeze(-1) + return torch.softmax(logits, dim=-1) + + +# --------------------------------------------------------------------------- +# sampling_from_probs +# --------------------------------------------------------------------------- + + +def sampling_from_probs( + probs: torch.Tensor, + indices: Optional[torch.Tensor] = None, + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> torch.Tensor: + """Category sampling from probabilities. + + Parameters + ---------- + probs : torch.Tensor + ``(batch_size, num_classes)`` or ``(unique_batch_size, num_classes)`` + when *indices* is provided. + indices : Optional[torch.Tensor] + Maps each output to a row in *probs*. + deterministic, generator, check_nan, seed, offset + See FlashInfer docs. + + Returns + ------- + torch.Tensor + Sampled token ids, shape ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.sampling_from_probs( + probs, + indices=indices, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + seed=seed, + offset=offset, + ) + + p = _resolve_indices(probs, indices) + samples = torch.multinomial(p.float(), num_samples=1, generator=generator).squeeze( + -1 + ) + return samples.to(torch.int32) + + +# --------------------------------------------------------------------------- +# sampling_from_logits +# --------------------------------------------------------------------------- + + +def sampling_from_logits( + logits: torch.Tensor, + indices: Optional[torch.Tensor] = None, + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> torch.Tensor: + """Category sampling from logits (applies softmax internally). + + Parameters + ---------- + logits : torch.Tensor + ``(batch_size, num_classes)``. + indices, deterministic, generator, check_nan, seed, offset + See FlashInfer docs. + + Returns + ------- + torch.Tensor + Sampled token ids, shape ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.sampling_from_logits( + logits, + indices=indices, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + seed=seed, + offset=offset, + ) + + probs = torch.softmax(logits.float(), dim=-1) + return sampling_from_probs( + probs, + indices=indices, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + ) + + +# --------------------------------------------------------------------------- +# top_p_sampling_from_probs +# --------------------------------------------------------------------------- + + +def top_p_sampling_from_probs( + probs: torch.Tensor, + top_p: Union[torch.Tensor, float], + indices: Optional[torch.Tensor] = None, + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> torch.Tensor: + """Top-p (nucleus) sampling from probabilities. + + Parameters + ---------- + probs : torch.Tensor + ``(batch_size, num_classes)``. + top_p : Union[torch.Tensor, float] + Top-p threshold. + indices, deterministic, generator, check_nan, seed, offset + See FlashInfer docs. + + Returns + ------- + torch.Tensor + Sampled token ids, shape ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.top_p_sampling_from_probs( + probs, + top_p, + indices=indices, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + seed=seed, + offset=offset, + ) + + p = _resolve_indices(probs, indices).float() + renormed = _torch_top_p_renorm_probs(p, top_p) + samples = torch.multinomial(renormed, num_samples=1, generator=generator).squeeze( + -1 + ) + return samples.to(torch.int32) + + +# --------------------------------------------------------------------------- +# top_k_sampling_from_probs +# --------------------------------------------------------------------------- + + +def top_k_sampling_from_probs( + probs: torch.Tensor, + top_k: Union[torch.Tensor, int], + indices: Optional[torch.Tensor] = None, + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> torch.Tensor: + """Top-k sampling from probabilities. + + Parameters + ---------- + probs : torch.Tensor + ``(batch_size, num_classes)``. + top_k : Union[torch.Tensor, int] + Top-k threshold. + indices, deterministic, generator, check_nan, seed, offset + See FlashInfer docs. + + Returns + ------- + torch.Tensor + Sampled token ids, shape ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.top_k_sampling_from_probs( + probs, + top_k, + indices=indices, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + seed=seed, + offset=offset, + ) + + p = _resolve_indices(probs, indices).float() + renormed = _torch_top_k_renorm_probs(p, top_k) + samples = torch.multinomial(renormed, num_samples=1, generator=generator).squeeze( + -1 + ) + return samples.to(torch.int32) + + +# --------------------------------------------------------------------------- +# min_p_sampling_from_probs +# --------------------------------------------------------------------------- + + +def min_p_sampling_from_probs( + probs: torch.Tensor, + min_p: Union[torch.Tensor, float], + indices: Optional[torch.Tensor] = None, + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> torch.Tensor: + """Min-p sampling from probabilities. + + Parameters + ---------- + probs : torch.Tensor + ``(batch_size, num_classes)``. + min_p : Union[torch.Tensor, float] + Min-p threshold. + indices, deterministic, generator, check_nan, seed, offset + See FlashInfer docs. + + Returns + ------- + torch.Tensor + Sampled token ids, shape ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.min_p_sampling_from_probs( + probs, + min_p, + indices=indices, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + seed=seed, + offset=offset, + ) + + p = _resolve_indices(probs, indices).float() + batch_size = p.shape[0] + min_p_t = _to_scalar_or_tensor(min_p, batch_size, p.device) + # min-p: keep tokens whose probability >= min_p * max_prob + max_probs = p.max(dim=-1, keepdim=True).values # (B,1) + threshold = min_p_t.unsqueeze(-1) * max_probs # (B,1) + mask = p < threshold + filtered = p.clone() + filtered[mask] = 0.0 + # renormalize + sums = filtered.sum(dim=-1, keepdim=True) + sums = sums.clamp(min=1e-8) + filtered = filtered / sums + samples = torch.multinomial(filtered, num_samples=1, generator=generator).squeeze( + -1 + ) + return samples.to(torch.int32) + + +# --------------------------------------------------------------------------- +# top_k_top_p_sampling_from_logits +# --------------------------------------------------------------------------- + + +def top_k_top_p_sampling_from_logits( + logits: torch.Tensor, + top_k: Union[torch.Tensor, int], + top_p: Union[torch.Tensor, float], + indices: Optional[torch.Tensor] = None, + filter_apply_order: str = "top_k_first", + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> torch.Tensor: + """Top-k + top-p sampling from pre-softmax logits. + + Parameters + ---------- + logits : torch.Tensor + ``(batch_size, num_classes)``. + top_k : Union[torch.Tensor, int] + top_p : Union[torch.Tensor, float] + filter_apply_order : str + ``"top_k_first"`` or ``"joint"``. + indices, deterministic, generator, check_nan, seed, offset + See FlashInfer docs. + + Returns + ------- + torch.Tensor + Sampled token ids, shape ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.top_k_top_p_sampling_from_logits( + logits, + top_k, + top_p, + indices=indices, + filter_apply_order=filter_apply_order, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + seed=seed, + offset=offset, + ) + + probs = torch.softmax(logits.float(), dim=-1) + return top_k_top_p_sampling_from_probs( + probs, + top_k, + top_p, + indices=indices, + filter_apply_order=filter_apply_order, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + ) + + +# --------------------------------------------------------------------------- +# top_k_top_p_sampling_from_probs +# --------------------------------------------------------------------------- + + +def top_k_top_p_sampling_from_probs( + probs: torch.Tensor, + top_k: Union[torch.Tensor, int], + top_p: Union[torch.Tensor, float], + indices: Optional[torch.Tensor] = None, + filter_apply_order: str = "top_k_first", + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> torch.Tensor: + """Top-k + top-p sampling from probabilities. + + Parameters + ---------- + probs : torch.Tensor + ``(batch_size, num_classes)``. + top_k : Union[torch.Tensor, int] + top_p : Union[torch.Tensor, float] + filter_apply_order : str + ``"top_k_first"`` or ``"joint"``. + indices, deterministic, generator, check_nan, seed, offset + See FlashInfer docs. + + Returns + ------- + torch.Tensor + Sampled token ids, shape ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.top_k_top_p_sampling_from_probs( + probs, + top_k, + top_p, + indices=indices, + filter_apply_order=filter_apply_order, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + seed=seed, + offset=offset, + ) + + p = _resolve_indices(probs, indices).float() + if filter_apply_order == "top_k_first": + p = _torch_top_k_renorm_probs(p, top_k) + p = _torch_top_p_renorm_probs(p, top_p) + else: + # joint: apply both filters simultaneously + p = _torch_top_k_renorm_probs(p, top_k) + p = _torch_top_p_renorm_probs(p, top_p) + samples = torch.multinomial(p, num_samples=1, generator=generator).squeeze(-1) + return samples.to(torch.int32) + + +# --------------------------------------------------------------------------- +# top_p_renorm_probs +# --------------------------------------------------------------------------- + + +def top_p_renorm_probs( + probs: torch.Tensor, + top_p: Union[torch.Tensor, float], +) -> torch.Tensor: + """Renormalize probabilities by top-p thresholding. + + Parameters + ---------- + probs : torch.Tensor + ``(batch_size, num_classes)``. + top_p : Union[torch.Tensor, float] + Top-p threshold in ``(0, 1)``. + + Returns + ------- + torch.Tensor + Renormalized probabilities. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.top_p_renorm_probs(probs, top_p) + + return _torch_top_p_renorm_probs(probs.float(), top_p).to(probs.dtype) + + +def _torch_top_p_renorm_probs( + probs: torch.Tensor, + top_p: Union[torch.Tensor, float], +) -> torch.Tensor: + """Pure-torch top-p renormalization (operates on float32).""" + sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) + cumsum = torch.cumsum(sorted_probs, dim=-1) + + if isinstance(top_p, (int, float)): + mask = cumsum - sorted_probs > top_p + else: + top_p_t = top_p.unsqueeze(-1) + mask = cumsum - sorted_probs > top_p_t + + sorted_probs[mask] = 0.0 + # scatter back + result = torch.zeros_like(probs) + result.scatter_(1, sorted_indices, sorted_probs) + # renormalize + sums = result.sum(dim=-1, keepdim=True).clamp(min=1e-8) + return result / sums + + +# --------------------------------------------------------------------------- +# top_k_renorm_probs +# --------------------------------------------------------------------------- + + +def top_k_renorm_probs( + probs: torch.Tensor, + top_k: Union[torch.Tensor, int], +) -> torch.Tensor: + """Renormalize probabilities by top-k thresholding. + + Parameters + ---------- + probs : torch.Tensor + ``(batch_size, num_classes)``. + top_k : Union[torch.Tensor, int] + Top-k threshold. + + Returns + ------- + torch.Tensor + Renormalized probabilities. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.top_k_renorm_probs(probs, top_k) + + return _torch_top_k_renorm_probs(probs.float(), top_k).to(probs.dtype) + + +def _torch_top_k_renorm_probs( + probs: torch.Tensor, + top_k: Union[torch.Tensor, int], +) -> torch.Tensor: + """Pure-torch top-k renormalization (operates on float32).""" + if isinstance(top_k, int): + # uniform top_k across batch + topk_vals, _ = torch.topk(probs, top_k, dim=-1) + threshold = topk_vals[:, -1:] # (B, 1) + else: + # per-request top_k: use sorting + sorted_probs, _ = torch.sort(probs, dim=-1, descending=True) + # gather the k-th value for each row + k_indices = (top_k.long() - 1).unsqueeze(-1) # (B, 1) + threshold = sorted_probs.gather(1, k_indices) # (B, 1) + + mask = probs < threshold + filtered = probs.clone() + filtered[mask] = 0.0 + sums = filtered.sum(dim=-1, keepdim=True).clamp(min=1e-8) + return filtered / sums + + +# --------------------------------------------------------------------------- +# top_k_mask_logits +# --------------------------------------------------------------------------- + + +def top_k_mask_logits( + logits: torch.Tensor, + top_k: Union[torch.Tensor, int], +) -> torch.Tensor: + """Mask logits by top-k thresholding (set non-top-k to -inf). + + Parameters + ---------- + logits : torch.Tensor + ``(batch_size, num_classes)``. + top_k : Union[torch.Tensor, int] + Top-k threshold. + + Returns + ------- + torch.Tensor + Masked logits with the same shape and dtype. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.top_k_mask_logits(logits, top_k) + + if isinstance(top_k, int): + topk_vals, _ = torch.topk(logits, top_k, dim=-1) + threshold = topk_vals[:, -1:] + else: + sorted_logits, _ = torch.sort(logits, dim=-1, descending=True) + k_indices = (top_k.long() - 1).unsqueeze(-1) + threshold = sorted_logits.gather(1, k_indices) + + mask = logits < threshold + result = logits.clone() + result[mask] = float("-inf") + return result + + +# --------------------------------------------------------------------------- +# chain_speculative_sampling +# --------------------------------------------------------------------------- + + +def chain_speculative_sampling( + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + target_probs: torch.Tensor, + maybe_output_accepted_token_num: Optional[torch.Tensor] = None, + maybe_output_emitted_draft_token_num: Optional[torch.Tensor] = None, + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Speculative sampling for sequence generation. + + Parameters + ---------- + draft_probs : torch.Tensor + ``(batch_size, num_speculate_tokens, vocab_size)``. + draft_token_ids : torch.Tensor + ``(batch_size, num_speculate_tokens)``. + target_probs : torch.Tensor + ``(batch_size, num_speculate_tokens + 1, vocab_size)``. + maybe_output_accepted_token_num : Optional[torch.Tensor] + If provided, accepted counts are added in-place. + maybe_output_emitted_draft_token_num : Optional[torch.Tensor] + If provided, emitted counts are added in-place. + deterministic, generator, seed, offset + See FlashInfer docs. + + Returns + ------- + output_token_ids : torch.Tensor + ``(batch_size, num_speculate_tokens + 1)``, rejected slots padded with -1. + output_accepted_token_num : torch.Tensor + ``(batch_size,)``. + output_emitted_draft_token_num : torch.Tensor + ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.chain_speculative_sampling( + draft_probs, + draft_token_ids, + target_probs, + maybe_output_accepted_token_num=maybe_output_accepted_token_num, + maybe_output_emitted_draft_token_num=maybe_output_emitted_draft_token_num, + deterministic=deterministic, + generator=generator, + seed=seed, + offset=offset, + ) + + return _torch_chain_speculative_sampling( + draft_probs, + draft_token_ids, + target_probs, + maybe_output_accepted_token_num, + maybe_output_emitted_draft_token_num, + generator, + ) + + +def _torch_chain_speculative_sampling( + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + target_probs: torch.Tensor, + maybe_output_accepted_token_num: Optional[torch.Tensor], + maybe_output_emitted_draft_token_num: Optional[torch.Tensor], + generator: Optional[torch.Generator], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Pure-torch chain speculative sampling. + + Implements the rejection-sampling algorithm from + "Accelerating Large Language Model Decoding with Speculative Sampling" + (Leviathan et al., 2023). 
+ """ + batch_size, num_spec, vocab_size = draft_probs.shape + device = draft_probs.device + + output_ids = torch.full( + (batch_size, num_spec + 1), -1, dtype=torch.int32, device=device + ) + accepted_count = torch.zeros(batch_size, dtype=torch.int32, device=device) + emitted_count = torch.zeros(batch_size, dtype=torch.int32, device=device) + + for b in range(batch_size): + all_accepted = True + for t in range(num_spec): + draft_tok = draft_token_ids[b, t].item() + p_draft = draft_probs[b, t, draft_tok].item() + p_target = target_probs[b, t, draft_tok].item() + + # independent acceptance check (for the metric) + if p_target >= p_draft: + accepted_count[b] += 1 + else: + r = torch.rand(1, generator=generator, device=device).item() + if r < p_target / max(p_draft, 1e-10): + accepted_count[b] += 1 + + # sequential chain: accept / reject + if all_accepted: + r = torch.rand(1, generator=generator, device=device).item() + if r < min(1.0, p_target / max(p_draft, 1e-10)): + output_ids[b, t] = draft_tok + emitted_count[b] += 1 + else: + # reject: sample from max(0, p_target - p_draft) + diff = target_probs[b, t].float() - draft_probs[b, t].float() + diff = torch.clamp(diff, min=0.0) + dsum = diff.sum() + if dsum > 1e-8: + diff = diff / dsum + else: + diff = target_probs[b, t].float() + diff = diff / diff.sum().clamp(min=1e-8) + resampled = torch.multinomial( + diff.unsqueeze(0), num_samples=1, generator=generator + ).item() + output_ids[b, t] = resampled + emitted_count[b] += 1 + all_accepted = False + + # bonus token (sampled from target at position after last emitted) + if all_accepted: + pos = num_spec + bonus_probs = target_probs[b, pos].float() + bonus_probs = bonus_probs / bonus_probs.sum().clamp(min=1e-8) + bonus = torch.multinomial( + bonus_probs.unsqueeze(0), num_samples=1, generator=generator + ).item() + output_ids[b, num_spec] = bonus + + if maybe_output_accepted_token_num is not None: + maybe_output_accepted_token_num.add_(accepted_count) + if 
maybe_output_emitted_draft_token_num is not None: + maybe_output_emitted_draft_token_num.add_(emitted_count) + + return output_ids, accepted_count, emitted_count + + +# --------------------------------------------------------------------------- +# Aliases (FlashInfer also exposes these) +# --------------------------------------------------------------------------- +top_p_renorm_prob = top_p_renorm_probs +top_k_renorm_prob = top_k_renorm_probs diff --git a/pymllm/layers/utils.py b/pymllm/layers/utils.py new file mode 100644 index 000000000..0dcbd1ac0 --- /dev/null +++ b/pymllm/layers/utils.py @@ -0,0 +1,45 @@ +"""Utility functions for layers.""" + +from typing import Any, Dict + +import torch + + +def set_weight_attrs( + weight: torch.Tensor, + weight_attrs: Dict[str, Any] | None, +) -> None: + """Set attributes on a weight tensor. + + This method is used to set attributes on a weight tensor. This method + will not overwrite existing attributes. + + Args: + weight: The weight tensor or parameter. + weight_attrs: A dictionary of attributes to set on the weight tensor. + Common attributes include: + - output_dim: The dimension along which to shard the weight (typically 0 for output dim) + - input_dim: The input dimension (typically 1 for input dim) + - weight_loader: A callable to load weights into this parameter + - packed_dim: The dimension along which the weight is packed (for quantization) + - packed_factor: The packing factor (for quantization) + + Example: + >>> weight = nn.Parameter(torch.empty(100, 64)) + >>> set_weight_attrs(weight, { + ... "output_dim": 0, + ... "input_dim": 1, + ... "weight_loader": my_loader_func, + ... }) + """ + if weight_attrs is None: + return + + for key, value in weight_attrs.items(): + if hasattr(weight, key): + raise AttributeError( + f"Overwriting existing tensor attribute: {key}. 
" + f"Existing value: {getattr(weight, key)}, " + f"New value: {value}" + ) + setattr(weight, key, value) diff --git a/pymllm/mem_cache/__init__.py b/pymllm/mem_cache/__init__.py new file mode 100644 index 000000000..c2ce06eba --- /dev/null +++ b/pymllm/mem_cache/__init__.py @@ -0,0 +1,37 @@ +from pymllm.mem_cache.memory_pool import ( + KVPool, + ReqToTokenPool, + TokenToKVPoolAllocator, + make_full_attention_net_mem_pool, + make_req_to_token_pool, +) +from pymllm.mem_cache.radix_cache import ( + EvictResult, + InsertResult, + MatchResult, + RadixCache, + RadixKey, + TreeNode, + hash_bytes, + hash_to_int64, + hash_token_ids, +) + +__all__ = [ + # memory_pool + "KVPool", + "TokenToKVPoolAllocator", + "ReqToTokenPool", + "make_full_attention_net_mem_pool", + "make_req_to_token_pool", + # radix_cache + "RadixCache", + "RadixKey", + "TreeNode", + "MatchResult", + "InsertResult", + "EvictResult", + "hash_token_ids", + "hash_to_int64", + "hash_bytes", +] diff --git a/pymllm/mem_cache/memory_pool.py b/pymllm/mem_cache/memory_pool.py new file mode 100644 index 000000000..9c8ab2a99 --- /dev/null +++ b/pymllm/mem_cache/memory_pool.py @@ -0,0 +1,639 @@ +"""Lightweight KV-cache memory pools + +Three-layer architecture:: + + ReqToTokenPool maps (req_slot, position) → kv_index + TokenToKVPoolAllocator manages a free-list of integer indices + KVPool holds the actual GPU K/V tensors + +All indices are **int32** tensors on the target device. Slot 0 in the KV +buffers is reserved as a padding / dummy-output slot and is never allocated. +""" + +import logging +from typing import List, Optional, Tuple, Union + +import torch + +from mllm_kernel.cuda.jit.store_cache import store_cache, can_use_store_cache + +logger = logging.getLogger(__name__) + + +class KVPool: + """GPU (or CPU) storage for per-layer key and value caches. 
+ + Layout per layer:: + + JIT: + k_buffer[layer][slot, k_head_num * k_head_dim] + v_buffer[layer][slot, v_head_num * v_head_dim] + + PyTorch: + k_buffer[layer][slot, k_head_num, k_head_dim] + v_buffer[layer][slot, v_head_num, v_head_dim] + + K and V may have **independent** head counts and head dimensions, which + covers standard MHA, GQA / MQA, and architectures like MLA where value + projection uses a different dimensionality. + + ``size`` usable slots are numbered ``[1, size]``. Slot 0 is a dummy + padding slot that absorbs writes from padded tokens. + + Parameters + ---------- + size : int + Number of usable token slots (total buffer length = ``size + 1``). + layer_num : int + Number of transformer layers (one K buffer + one V buffer per layer). + k_head_num : int + Number of key heads. + k_head_dim : int + Dimension of each key head. + device : str | torch.device + Target device (``"cuda"``, ``"cpu"``, …). + dtype : torch.dtype + Storage data type. + v_head_num : int, optional + Number of value heads. Defaults to *k_head_num*. + v_head_dim : int, optional + Dimension of each value head. Defaults to *k_head_dim*. + pin_memory : bool, optional + Whether to use pinned memory. Defaults to True. 
+ """ + + def __init__( + self, + size: int, + layer_num: int, + k_head_num: int, + k_head_dim: int, + device: Union[str, torch.device] = "cuda", + dtype: torch.dtype = torch.float16, + v_head_num: Optional[int] = None, + v_head_dim: Optional[int] = None, + pin_memory: bool = True, + ): + self.size = size + self.layer_num = layer_num + self.k_head_num = k_head_num + self.k_head_dim = k_head_dim + self.v_head_num = v_head_num if v_head_num is not None else k_head_num + self.v_head_dim = v_head_dim if v_head_dim is not None else k_head_dim + self.device = torch.device(device) + self.dtype = dtype + + # pin_memory only applies to CPU tensors + if self.device.type != "cpu": + pin_memory = False + + buf_len = size + 1 # slot 0 is padding + + if buf_len % 8 != 0: + logger.warning( + "KVPool buffer length is not divisible by 8, padding to the next multiple of 8" + ) + buf_len = (buf_len + 7) & ~7 + + k_row_dim = self.k_head_num * self.k_head_dim + v_row_dim = self.v_head_num * self.v_head_dim + self._same_kv_dim = k_row_dim == v_row_dim + self._row_bytes = k_row_dim * torch.tensor([], dtype=dtype).element_size() + self._use_jit = ( + self.device.type == "cuda" + and self._same_kv_dim + and can_use_store_cache(self._row_bytes) + ) + if not self._use_jit: + logger.warning( + f"Fallback to PyTorch index for KVPool, which is slower than the mllm-kernel's implementation, same_kv_dim={self._same_kv_dim}, row_bytes={self._row_bytes}" + ) + + self.k_buffer: List[torch.Tensor] = [ + torch.zeros( + (buf_len, self.k_head_num, self.k_head_dim), + dtype=dtype, + device=self.device, + pin_memory=pin_memory, + ) + for _ in range(layer_num) + ] + self.v_buffer: List[torch.Tensor] = [ + torch.zeros( + (buf_len, self.v_head_num, self.v_head_dim), + dtype=dtype, + device=self.device, + pin_memory=pin_memory, + ) + for _ in range(layer_num) + ] + + # Pre-computed 2D views for the JIT store_cache kernel. + # Zero-copy: same underlying storage as k_buffer / v_buffer. 
+ if self._use_jit: + self._k_buffer_2d = [b.view(buf_len, -1) for b in self.k_buffer] + self._v_buffer_2d = [b.view(buf_len, -1) for b in self.v_buffer] + + logger.info( + "KVPool allocated: %d layers, %d slots, K=[%d,%d] V=[%d,%d], %.2f GB", + layer_num, + size, + self.k_head_num, + self.k_head_dim, + self.v_head_num, + self.v_head_dim, + self._mem_bytes() / (1 << 30), + ) + + def get_key_buffer(self, layer_id: int) -> torch.Tensor: + return self.k_buffer[layer_id] + + def get_value_buffer(self, layer_id: int) -> torch.Tensor: + return self.v_buffer[layer_id] + + def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: + return self.k_buffer[layer_id], self.v_buffer[layer_id] + + def set_kv_buffer( + self, + layer_id: int, + indices: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> None: + """Write K/V vectors into the cache at the given *indices*. + + ``k`` / ``v`` can be any shape as long as the trailing dimensions + multiply to ``head_num * head_dim`` (the row dimension). All leading + dimensions are treated as the batch axis and must match ``indices`` + after flattening. Typical shapes:: + + k: [num_tokens, head_num, head_dim] indices: [num_tokens] + k: [batch, seq_len, head_num, head_dim] indices: [batch, seq_len] + k: [num_tokens, head_num * head_dim] indices: [num_tokens] + """ + if self._use_jit: + row_dim = self.k_head_num * self.k_head_dim + store_cache( + k.reshape(-1, row_dim), + v.reshape(-1, row_dim), + self._k_buffer_2d[layer_id], + self._v_buffer_2d[layer_id], + indices.reshape(-1), + row_bytes=self._row_bytes, + ) + else: + self.k_buffer[layer_id][indices] = k + self.v_buffer[layer_id][indices] = v + + def _mem_bytes(self) -> int: + total = 0 + for buf in self.k_buffer + self.v_buffer: + total += buf.nelement() * buf.element_size() + return total + + +class TokenToKVPoolAllocator: + """Manages allocation / deallocation of integer indices into a :class:`KVPool`. 
+ + Each ``alloc(n)`` returns *n* free indices; each ``free(indices)`` returns + them to the pool. + + Uses a **dual-buffer** strategy (``free_slots`` + ``release_slots``) so + that ``free()`` never cats onto the large main free-list. Freed indices + accumulate in the smaller ``release_slots`` and are merged lazily (with an + optional sort) only when ``alloc()`` cannot be satisfied from + ``free_slots`` alone. + + A **batch-free** API (``free_group_begin`` / ``free_group_end``) further + amortises cost when many ``free()`` calls happen in a tight loop (e.g. + during scheduling or eviction). + + Typical usage:: + + allocator = TokenToKVPoolAllocator(size=4096, device="cuda") + + # --- basic alloc / free --- + indices = allocator.alloc(128) # 128 free slot indices (int32) + allocator.free(indices[:64]) # return 64 slots + + # --- batch free (amortised) --- + allocator.free_group_begin() + for req in finished_requests: + allocator.free(req.kv_indices) # O(1) list append each + allocator.free_group_end() # single torch.cat + release + + Parameters + ---------- + size : int + Total number of allocatable slots (must match ``KVPool.size``). + device : str | torch.device + Device for the free-list tensor. + page_size : int + When > 1 the allocator works in page-aligned mode: ``alloc`` returns + multiples of ``page_size`` contiguous within each page, and ``free`` + deduplicates by page. + need_sort : bool + When ``True`` (default), ``merge_and_sort_free`` sorts after merging + so that lower-index slots are allocated first (better memory locality). + """ + + def __init__( + self, + size: int, + device: Union[str, torch.device] = "cuda", + page_size: int = 1, + need_sort: bool = True, + ): + self.size = size + self.page_size = page_size + self.device = torch.device(device) + self.need_sort = need_sort + self.clear() + + def clear(self) -> None: + """Reset the allocator so that all slots ``[1, size]`` are free. 
The first slot is reserved for padding."""
+        if self.page_size == 1:
+            # Token mode: every slot in [1, size] is individually allocatable.
+            self.free_slots = torch.arange(
+                1, self.size + 1, dtype=torch.int32, device=self.device
+            )
+        else:
+            # Page mode: free_slots holds PAGE ids, not token indices.
+            # Page 0 is excluded so the padding slot 0 is never handed out.
+            # NOTE(review): pages run 1..num_pages, and alloc() maps page p to
+            # token indices [p*page_size, (p+1)*page_size), so the top page
+            # reaches index size + page_size - 1 -- past the nominal ``size``
+            # slots. This appears to rely on KVPool rounding its buffer length
+            # up; confirm the buffers are long enough for page_size > 1.
+            num_pages = self.size // self.page_size
+            self.free_slots = torch.arange(
+                1, num_pages + 1, dtype=torch.int32, device=self.device
+            )
+        self.release_slots = torch.empty((0,), dtype=torch.int32, device=self.device)
+        self._is_not_in_free_group = True
+        self._free_group: List[torch.Tensor] = []
+
+    def available_size(self) -> int:
+        """Number of tokens that can still be allocated."""
+        # Both lists hold page ids in page mode, hence the * page_size.
+        return (len(self.free_slots) + len(self.release_slots)) * self.page_size
+
+    def merge_and_sort_free(self) -> None:
+        """Merge ``release_slots`` into ``free_slots`` (and sort if ``need_sort``)."""
+        if len(self.release_slots) == 0:
+            return
+        self.free_slots = torch.cat((self.free_slots, self.release_slots))
+        if self.need_sort:
+            # Sorting makes alloc() prefer low indices (memory locality).
+            self.free_slots, _ = torch.sort(self.free_slots)
+        self.release_slots = torch.empty((0,), dtype=torch.int32, device=self.device)
+
+    def free_group_begin(self) -> None:
+        """Start collecting ``free()`` calls; actual release is deferred to ``free_group_end``."""
+        self._is_not_in_free_group = False
+        self._free_group = []
+
+    def free_group_end(self) -> None:
+        """Flush all ``free()`` calls collected since ``free_group_begin``."""
+        self._is_not_in_free_group = True
+        if self._free_group:
+            # Single cat + one real free() instead of N small ones.
+            self.free(torch.cat(self._free_group))
+        self._free_group = []
+
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
+        """Allocate *need_size* token indices.
+
+        Returns a 1-D ``int32`` tensor on success, or ``None`` if the pool is
+        exhausted.
+        """
+        if self.page_size == 1:
+            if need_size > len(self.free_slots):
+                # Lazily merge the release buffer only when we actually run
+                # short -- keeps the common alloc path cat-free.
+                self.merge_and_sort_free()
+                if need_size > len(self.free_slots):
+                    return None
+            out = self.free_slots[:need_size]
+            self.free_slots = self.free_slots[need_size:]
+            return out
+
+        # Page mode: round the request up to whole pages, then expand each
+        # page id into its page_size contiguous token indices.
+        num_pages = (need_size + self.page_size - 1) // self.page_size
+        if num_pages > len(self.free_slots):
+            self.merge_and_sort_free()
+            if num_pages > len(self.free_slots):
+                return None
+        pages = self.free_slots[:num_pages]
+        self.free_slots = self.free_slots[num_pages:]
+        offsets = torch.arange(self.page_size, device=self.device)
+        out = (pages[:, None] * self.page_size + offsets).reshape(-1)
+        # Tail of the last page is dropped here; free() re-derives page ids
+        # by integer division, so partially-used pages are still reclaimed.
+        return out[:need_size]
+
+    def free(self, indices: torch.Tensor) -> None:
+        """Return *indices* to the free pool."""
+        if indices.numel() == 0:
+            return
+
+        if not self._is_not_in_free_group:
+            # Inside a free-group: just record; free_group_end() flushes.
+            self._free_group.append(indices)
+            return
+
+        if self.page_size != 1:
+            # Convert token indices back to deduplicated page ids.
+            indices = torch.unique(indices // self.page_size)
+
+        if self.need_sort:
+            # Defer to release_slots; merged (and sorted) lazily on demand.
+            self.release_slots = torch.cat((self.release_slots, indices))
+        else:
+            self.free_slots = torch.cat((self.free_slots, indices))
+
+
+class ReqToTokenPool:
+    """Maps each live request to its per-position KV-pool indices.
+
+    Internally a 2-D tensor ``req_to_token[slot, position]`` stores the
+    KV-pool index for every token position of every active request.
+    Slots are recycled via a simple free-list.
+
+    This class is a **pure mapping table** -- it does **not** track per-request
+    sequence lengths. The caller (typically the ``Req`` / IO-struct object)
+    must store ``req_pool_idx`` and ``seq_len`` and use them to slice into
+    ``req_to_token`` when reading back KV indices.
+ + Typical usage:: + + pool = ReqToTokenPool(max_reqs=256, max_context_len=4096) + + # --- on new request arrival --- + [slot] = pool.alloc(1) # slot = req_pool_idx + kv_indices = kv_allocator.alloc(seq_len) # from TokenToKVPoolAllocator + pool.write((slot, slice(0, seq_len)), kv_indices) + + # --- read back (caller tracks seq_len) --- + kv_indices = pool.req_to_token[slot, :seq_len] + + # --- on request completion --- + kv_allocator.free(pool.req_to_token[slot, :seq_len]) + pool.free(slot) + + Parameters + ---------- + max_reqs : int + Maximum number of concurrent requests (number of rows). + max_context_len : int + Maximum sequence length any single request can reach (number of cols). + device : str | torch.device + Target device for the mapping tensor. + """ + + def __init__( + self, + max_reqs: int, + max_context_len: int, + device: Union[str, torch.device] = "cuda", + ): + self.size = max_reqs + self.max_context_len = max_context_len + self.device = torch.device(device) + + self.req_to_token = torch.zeros( + (max_reqs, max_context_len), dtype=torch.int32, device=self.device + ) + self._free_slots: List[int] = list(range(max_reqs)) + + def available_size(self) -> int: + return len(self._free_slots) + + def alloc(self, n: int = 1) -> Optional[List[int]]: + """Allocate *n* request slots. Returns a list of slot indices.""" + if n > len(self._free_slots): + return None + out = self._free_slots[:n] + self._free_slots = self._free_slots[n:] + return out + + def free(self, slot: int) -> None: + """Return a single request slot to the pool.""" + self._free_slots.append(slot) + + def write(self, index: Tuple, values: torch.Tensor) -> None: + """Write KV indices into the mapping table. + + ``index`` is typically ``(req_pool_idx, slice(start, end))``. 
+ """ + self.req_to_token[index] = values + + def clear(self) -> None: + self._free_slots = list(range(self.size)) + self.req_to_token.zero_() + + +def make_full_attention_net_mem_pool( + size: int, + layer_num: int, + k_head_num: int, + k_head_dim: int, + v_head_num: int, + v_head_dim: int, + device: Union[str, torch.device] = "cuda", + dtype: torch.dtype = torch.float16, + page_size: int = 1, + need_sort: bool = True, + pin_memory: bool = True, +) -> Tuple[KVPool, TokenToKVPoolAllocator]: + """Create a :class:`KVPool` and its :class:`TokenToKVPoolAllocator` for a + full-attention (non-SWA) model. + + Parameters + ---------- + size : int + Number of usable token slots in the KV cache. + layer_num : int + Number of transformer layers. + k_head_num / k_head_dim : int + Key head count and dimension. + v_head_num / v_head_dim : int + Value head count and dimension. + device : str | torch.device + Target device. + dtype : torch.dtype + Storage data type for the KV buffers. + page_size : int + Allocator page size (1 = per-token, >1 = page-aligned). + need_sort : bool + Whether the allocator sorts on merge for memory locality. + pin_memory : bool + Whether to use pinned memory for the KV buffers. + + Returns + ------- + (KVPool, TokenToKVPoolAllocator) + """ + pool = KVPool( + size=size, + layer_num=layer_num, + k_head_num=k_head_num, + k_head_dim=k_head_dim, + device=device, + dtype=dtype, + v_head_num=v_head_num, + v_head_dim=v_head_dim, + pin_memory=pin_memory, + ) + allocator = TokenToKVPoolAllocator( + size=size, + device=device, + page_size=page_size, + need_sort=need_sort, + ) + return pool, allocator + + +class GDNPool: + """Pre-allocated memory pool for GDN recurrent and conv states. + + Indexed by ``req_pool_idx`` (same index space as :class:`ReqToTokenPool`). + Slot 0 is reserved as a padding / dummy slot and is never allocated. 
+ + Layout:: + + recurrent_state[gdn_layer_idx, slot, num_v_heads, head_k_dim, head_v_dim] + float32 (FlashInfer requirement) + conv_state[gdn_layer_idx, slot, conv_dim, kernel_size - 1] + model dtype (bfloat16 / float16) + + Parameters + ---------- + max_reqs : int + Maximum number of concurrent requests (matches ``ReqToTokenPool.size``). + num_gdn_layers : int + Number of GDN (linear attention) layers in the model. + num_v_heads : int + Number of value heads per GDN layer. + head_k_dim : int + Per-head key dimension. + head_v_dim : int + Per-head value dimension. + conv_dim : int + Total convolution input dimension (``key_dim * 2 + value_dim``). + conv_kernel_size : int + Causal conv1d kernel width (state stores ``kernel_size - 1`` columns). + device : str | torch.device + Target device. + dtype : torch.dtype + Storage dtype for conv_state (recurrent_state is always float32). + """ + + def __init__( + self, + max_reqs: int, + num_gdn_layers: int, + num_v_heads: int, + head_k_dim: int, + head_v_dim: int, + conv_dim: int, + conv_kernel_size: int, + device: Union[str, torch.device] = "cuda", + dtype: torch.dtype = torch.bfloat16, + max_track_slots: int = 0, + ): + self.max_reqs = max_reqs + self.num_gdn_layers = num_gdn_layers + self.num_v_heads = num_v_heads + self.head_k_dim = head_k_dim + self.head_v_dim = head_v_dim + self.conv_dim = conv_dim + self.conv_kernel_size = conv_kernel_size + self.device = torch.device(device) + self.dtype = dtype + self.max_track_slots = max_track_slots + + # Track slots live after the working slots: indices + # [max_reqs + 1, max_reqs + 1 + max_track_slots) + pool_size = max_reqs + 1 + max_track_slots # slot 0 is padding + + # Recurrent state: always float32 (FlashInfer requirement) + # Shape: [num_gdn_layers, pool_size, num_v_heads, head_v_dim, head_k_dim] + # Note: FlashInfer uses (V, K) layout for the state matrix + self.recurrent_state = torch.zeros( + (num_gdn_layers, pool_size, num_v_heads, head_v_dim, head_k_dim), + 
dtype=torch.float32, + device=self.device, + ) + + # Conv state: model dtype + # Shape: [num_gdn_layers, pool_size, conv_dim, kernel_size - 1] + self.conv_state = torch.zeros( + (num_gdn_layers, pool_size, conv_dim, conv_kernel_size - 1), + dtype=dtype, + device=self.device, + ) + + # Track-slot free list (indices into the pool starting after working slots) + self._track_slot_base = max_reqs + 1 + self._free_track_slots: List[int] = list( + range(self._track_slot_base, self._track_slot_base + max_track_slots) + ) + + logger.info( + "GDNPool allocated: %d GDN layers, %d working + %d track slots, " + "v_heads=%d, k_dim=%d, v_dim=%d, conv_dim=%d, kernel=%d, %.2f GB", + num_gdn_layers, + max_reqs, + max_track_slots, + num_v_heads, + head_k_dim, + head_v_dim, + conv_dim, + conv_kernel_size, + self.mem_bytes() / (1 << 30), + ) + + def get_layer_state( + self, gdn_layer_idx: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Return ``(recurrent_state, conv_state)`` for a specific GDN layer. + + Both are views into the pool tensors with shape: + - recurrent: ``[pool_size, num_v_heads, head_v_dim, head_k_dim]`` + - conv: ``[pool_size, conv_dim, kernel_size - 1]`` + """ + return ( + self.recurrent_state[gdn_layer_idx], + self.conv_state[gdn_layer_idx], + ) + + def reset_states(self, req_pool_indices: torch.Tensor) -> None: + """Zero-init GDN states for the given request pool indices. + + Called when new requests are allocated to ensure clean state. + """ + if req_pool_indices.numel() == 0: + return + # Zero both recurrent and conv states for all GDN layers + self.recurrent_state[:, req_pool_indices] = 0 + self.conv_state[:, req_pool_indices] = 0 + + # ------------------------------------------------------------------ + # Track-slot management (for prefix cache GDN state snapshots) + # ------------------------------------------------------------------ + + def alloc_track_slot(self) -> Optional[int]: + """Allocate a single track slot index. 
Returns ``None`` if exhausted.""" + if not self._free_track_slots: + return None + return self._free_track_slots.pop() + + def free_track_slot(self, slot: int) -> None: + """Return a track slot to the free list.""" + self._free_track_slots.append(slot) + + def copy_states(self, src_index: int, dst_index: int) -> None: + """Copy recurrent and conv states from *src_index* to *dst_index*. + + Works for any pool indices (working or track slots). + """ + self.recurrent_state[:, dst_index] = self.recurrent_state[:, src_index] + self.conv_state[:, dst_index] = self.conv_state[:, src_index] + + def mem_bytes(self) -> int: + """Total memory consumption in bytes.""" + return ( + self.recurrent_state.nelement() * self.recurrent_state.element_size() + + self.conv_state.nelement() * self.conv_state.element_size() + ) + + +def make_req_to_token_pool( + max_reqs: int, + max_context_len: int, + device: Union[str, torch.device] = "cuda", +) -> ReqToTokenPool: + return ReqToTokenPool(max_reqs, max_context_len, device) diff --git a/pymllm/mem_cache/radix_cache.py b/pymllm/mem_cache/radix_cache.py new file mode 100644 index 000000000..441a8c097 --- /dev/null +++ b/pymllm/mem_cache/radix_cache.py @@ -0,0 +1,808 @@ +"""Lightweight radix-tree KV cache with SWA and multimodal support. 
+ + +Supports: + - Multi-batch serving on a single GPU + - Sliding Window Attention (SWA) via tombstone mechanism + - Multimodal namespace isolation via ``extra_key`` + - SHA256 position-aware hashing + - Page-aligned operations (page_size >= 1) + - LRU leaf eviction +""" + +from __future__ import annotations + +import hashlib +import heapq +import logging +import time +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union + +import torch + +logger = logging.getLogger(__name__) + + +def hash_token_ids( + token_ids: List[Union[int, Tuple[int, ...]]], + prior_hash: Optional[str] = None, +) -> str: + """SHA-256 hash of a token-id page with optional chain-hash. + + Each token is encoded as a 4-byte little-endian unsigned integer; + tuples (bigram / EAGLE) hash each element in order. When *prior_hash* + is supplied the digest is seeded with the raw bytes of the previous + hash, making the result position-aware. + """ + hasher = hashlib.sha256() + if prior_hash: + hasher.update(bytes.fromhex(prior_hash)) + for t in token_ids: + if isinstance(t, tuple): + for elem in t: + hasher.update(elem.to_bytes(4, byteorder="little", signed=False)) + else: + hasher.update(t.to_bytes(4, byteorder="little", signed=False)) + return hasher.hexdigest() + + +def hash_to_int64(hex_str: str) -> int: + """Convert a hex digest to a signed 64-bit integer (first 16 hex chars).""" + val = int(hex_str[:16], 16) + return val - (1 << 64) if val >= (1 << 63) else val + + +def hash_bytes(data: bytes) -> int: + """SHA-256 → unsigned 64-bit int. Useful for multimodal embedding keys.""" + return int.from_bytes(hashlib.sha256(data).digest()[:8], "big", signed=False) + + +class RadixKey: + """Compound lookup key: token-id sequence + optional namespace tag. 
+ + ``extra_key`` isolates independent namespaces so that sequences with + identical leading tokens but different adapters / LoRA ids / multimodal + context hashes never share prefix nodes. + """ + + __slots__ = ("token_ids", "extra_key") + + def __init__( + self, + token_ids: List[Union[int, Tuple[int, ...]]], + extra_key: Optional[str] = None, + ): + self.token_ids = token_ids + self.extra_key = extra_key + + def __len__(self) -> int: + return len(self.token_ids) + + def __iter__(self) -> Iterator: + return iter(self.token_ids) + + def __getitem__(self, idx: Union[int, slice]) -> RadixKey: + if isinstance(idx, slice): + return RadixKey(self.token_ids[idx], self.extra_key) + return RadixKey([self.token_ids[idx]], self.extra_key) + + def __repr__(self) -> str: + preview = self.token_ids[:10] + tail = "..." if len(self.token_ids) > 10 else "" + return f"RadixKey(extra={self.extra_key!r}, toks={preview}{tail})" + + +_node_counter: int = 0 + + +def _next_node_id() -> int: + global _node_counter + _node_counter += 1 + return _node_counter + + +class TreeNode: + """A single node in the radix tree. + + ``value`` holds a 1-D ``int64`` tensor of KV-pool indices (one per token + in ``key``). When the node has been evicted, ``value`` is ``None``. 
+ """ + + __slots__ = ( + "children", + "parent", + "key", + "value", + "lock_ref", + "swa_lock_ref", + "swa_tombstone", + "swa_boundary_id", + "last_access_time", + "hit_count", + "hash_values", + "id", + ) + + def __init__(self) -> None: + self.children: Dict[Any, TreeNode] = defaultdict(TreeNode) + self.parent: Optional[TreeNode] = None + self.key: Optional[RadixKey] = None + self.value: Optional[torch.Tensor] = None + + self.lock_ref: int = 0 + self.swa_lock_ref: int = 0 + self.swa_tombstone: bool = False + self.swa_boundary_id: Optional[int] = None + + self.last_access_time: float = time.monotonic() + self.hit_count: int = 0 + self.hash_values: Optional[List[str]] = None + self.id: int = _next_node_id() + + @property + def evicted(self) -> bool: + return self.value is None + + def __lt__(self, other: TreeNode) -> bool: + return self.last_access_time < other.last_access_time + + +def _key_match(key0: RadixKey, key1: RadixKey, page_size: int) -> int: + """Return the length of the common prefix (page-aligned when *page_size* > 1).""" + if key0.extra_key != key1.extra_key: + return 0 + if page_size == 1: + i = 0 + for a, b in zip(key0.token_ids, key1.token_ids): + if a != b: + break + i += 1 + return i + min_len = min(len(key0), len(key1)) + i = 0 + while i < min_len: + if key0.token_ids[i : i + page_size] != key1.token_ids[i : i + page_size]: + break + i += page_size + return i + + +def _child_key(key: RadixKey, page_size: int) -> Any: + """Derive the dict key used in ``node.children``.""" + plain = key.token_ids[0] if page_size == 1 else tuple(key.token_ids[:page_size]) + return (key.extra_key, plain) if key.extra_key is not None else plain + + +@dataclass +class MatchResult: + """Returned by :meth:`RadixCache.match_prefix`.""" + + indices: torch.Tensor + last_node: TreeNode + prefix_len: int = 0 + + +@dataclass +class InsertResult: + """Returned by :meth:`RadixCache.insert`.""" + + prefix_len: int = 0 + last_node: Optional[TreeNode] = None + + +@dataclass 
+class EvictResult: + """Returned by :meth:`RadixCache.evict`.""" + + full_evicted: int = 0 + swa_evicted: int = 0 + + +class RadixCache: + """Lightweight radix tree for KV-cache prefix sharing. + + Parameters + ---------- + page_size: + Number of tokens per KV-pool page. Keys and values are aligned to + this granularity. + sliding_window_size: + If set, enables SWA mode. The cache tracks which nodes have had + their SWA KV freed (tombstoned) and constrains prefix matching + so that the sliding-window invariant is maintained. + disable: + When *True* every public method is a no-op (useful for ablation). + token_to_kv_pool_allocator: + Optional pool allocator with ``free(indices)`` (and ``free_swa`` for + SWA mode). When *None*, index tensors are simply discarded. + """ + + def __init__( + self, + page_size: int = 1, + sliding_window_size: Optional[int] = None, + disable: bool = False, + token_to_kv_pool_allocator: Any = None, + on_node_evict: Optional[Callable[[int], None]] = None, + ): + self.page_size = page_size + self.sliding_window_size = sliding_window_size + self.disable = disable + self.pool = token_to_kv_pool_allocator + self.on_node_evict = on_node_evict + + if self.pool is not None and hasattr(self.pool, "device"): + self.device = self.pool.device + else: + self.device = torch.device("cpu") + + self._swa_boundary_counter: int = 0 + self.reset() + + @property + def supports_swa(self) -> bool: + return self.sliding_window_size is not None + + def evictable_size(self) -> int: + return self._evictable_size + + def swa_evictable_size(self) -> int: + return self._swa_evictable_size + + def protected_size(self) -> int: + return self._protected_size + + def swa_protected_size(self) -> int: + return self._swa_protected_size + + def reset(self) -> None: + """Clear all cached state and re-initialise the root node.""" + self.root_node = TreeNode() + self.root_node.key = RadixKey([]) + self.root_node.value = torch.tensor([], dtype=torch.int64) + 
self.root_node.lock_ref = 1 + self.root_node.swa_lock_ref = 1 + self._evictable_size: int = 0 + self._swa_evictable_size: int = 0 + self._protected_size: int = 0 + self._swa_protected_size: int = 0 + + def match_prefix(self, key: RadixKey) -> MatchResult: + """Find the longest cached prefix of *key*. + + For SWA mode the match is further constrained: the path from the + returned ``last_node`` to root must have at least + ``sliding_window_size`` non-tombstone tokens (or be entirely + tombstone-free back to root). + + Accessing a prefix refreshes LRU timestamps along the matched path. + """ + empty = MatchResult( + indices=torch.empty(0, dtype=torch.int64, device=self.device), + last_node=self.root_node, + ) + if self.disable or len(key) == 0: + return empty + + key = self._page_align_key(key) + if len(key) == 0: + return empty + + if self.supports_swa: + values, last_node, best_count = self._match_swa(key) + values = values[:best_count] + else: + values, last_node = self._match_normal(key) + + cat = ( + torch.cat(values) + if values + else torch.empty(0, dtype=torch.int64, device=self.device) + ) + return MatchResult(indices=cat, last_node=last_node, prefix_len=len(cat)) + + def insert( + self, + key: RadixKey, + value: Optional[torch.Tensor] = None, + *, + prev_prefix_len: int = 0, + swa_evicted_seqlen: int = 0, + ) -> InsertResult: + """Insert *key*/*value* into the tree. + + Returns how many leading tokens were already present (the prefix + length). The caller is responsible for freeing duplicate KV indices + in the range ``[cache_protected_len, prefix_len)``. + + Parameters + ---------- + prev_prefix_len: + (SWA mode) tokens before this offset are already protected and + should not have their values overwritten. + swa_evicted_seqlen: + (SWA mode) the sequence length up to which SWA KV has been + previously evicted. Used to decide whether a tombstoned node can + be un-tombstoned with the incoming value. 
+ """ + if self.disable: + return InsertResult() + if value is None: + value = torch.tensor(key.token_ids, dtype=torch.int64) + if self.supports_swa: + plen = self._insert_swa( + self.root_node, key, value, prev_prefix_len, swa_evicted_seqlen + ) + return InsertResult(prefix_len=plen) + else: + plen, last_node = self._insert_normal(self.root_node, key, value) + return InsertResult(prefix_len=plen, last_node=last_node) + + def evict(self, num_tokens: int, swa_num_tokens: int = 0) -> EvictResult: + """Evict up to *num_tokens* (full) and *swa_num_tokens* (SWA) tokens. + + Full eviction removes leaf nodes entirely; SWA eviction tombstones + internal nodes (freeing SWA KV but retaining full-attn KV). + """ + if self.disable: + return EvictResult() + + full_evicted = 0 + swa_evicted = 0 + + # Phase 1: full leaf eviction + if num_tokens > 0: + leaves = self._collect_evictable_leaves() + heap: List[Tuple[float, TreeNode]] = [ + (n.last_access_time, n) for n in leaves + ] + heapq.heapify(heap) + + while full_evicted < num_tokens and heap: + _, node = heapq.heappop(heap) + if node.evicted or node.lock_ref > 0: + continue + n = len(node.value) + self._free_indices(node.value) + full_evicted += n + swa_evicted += n + self._delete_leaf(node) + + p = node.parent + if ( + p is not None + and p != self.root_node + and len(p.children) == 0 + and p.lock_ref == 0 + ): + if self.supports_swa and p.swa_tombstone: + self._free_indices(p.value) + full_evicted += len(p.value) + self._delete_leaf(p) + else: + heapq.heappush(heap, (p.last_access_time, p)) + + # Phase 2: SWA tombstone eviction (internal nodes) + if self.supports_swa and swa_evicted < swa_num_tokens: + candidates = self._collect_swa_evictable() + heap2: List[Tuple[float, TreeNode]] = [ + (n.last_access_time, n) for n in candidates + ] + heapq.heapify(heap2) + + while swa_evicted < swa_num_tokens and heap2: + _, node = heapq.heappop(heap2) + if node.swa_tombstone or node.swa_lock_ref > 0 or node.evicted: + continue + n = 
len(node.value) + if len(node.children) == 0 and node.lock_ref == 0: + self._free_indices(node.value) + full_evicted += n + swa_evicted += n + self._delete_leaf(node) + elif len(node.children) > 0: + self._free_swa_indices(node.value) + swa_evicted += n + self._tombstone_node(node) + + return EvictResult(full_evicted=full_evicted, swa_evicted=swa_evicted) + + def inc_lock_ref(self, node: TreeNode) -> Optional[int]: + """Lock nodes from *node* up to root (prevents eviction). + + Returns ``swa_boundary_id`` that must be passed back to + :meth:`dec_lock_ref`. In non-SWA mode, returns ``None``. + """ + if self.disable or node is None: + return None + + swa_locked = 0 + swa_boundary_id: Optional[int] = None + cur = node + while cur != self.root_node: + if cur.lock_ref == 0: + self._evictable_size -= len(cur.key) + self._protected_size += len(cur.key) + cur.lock_ref += 1 + + if ( + self.supports_swa + and swa_locked < self.sliding_window_size + and not cur.swa_tombstone + ): + if cur.swa_lock_ref == 0: + self._swa_evictable_size -= len(cur.key) + self._swa_protected_size += len(cur.key) + cur.swa_lock_ref += 1 + swa_locked += len(cur.key) + if swa_locked >= self.sliding_window_size: + if cur.swa_boundary_id is None: + self._swa_boundary_counter += 1 + cur.swa_boundary_id = self._swa_boundary_counter + swa_boundary_id = cur.swa_boundary_id + + cur = cur.parent + return swa_boundary_id + + def dec_lock_ref( + self, node: TreeNode, swa_boundary_id: Optional[int] = None + ) -> None: + """Unlock nodes from *node* up to root.""" + if self.disable or node is None: + return + + dec_swa = True + cur = node + while cur != self.root_node: + if cur.lock_ref == 1: + self._evictable_size += len(cur.key) + self._protected_size -= len(cur.key) + cur.lock_ref -= 1 + + if self.supports_swa and dec_swa and not cur.swa_tombstone: + if cur.swa_lock_ref == 1: + self._swa_evictable_size += len(cur.key) + self._swa_protected_size -= len(cur.key) + cur.swa_lock_ref -= 1 + if swa_boundary_id and 
cur.swa_boundary_id == swa_boundary_id: + dec_swa = False + + cur = cur.parent + + def total_size(self) -> int: + """Total number of cached tokens (including tombstoned).""" + total = 0 + stack: List[TreeNode] = [self.root_node] + while stack: + n = stack.pop() + if n.value is not None: + total += len(n.value) + stack.extend(c for c in n.children.values() if not c.evicted) + return total + + def compute_node_hash(self, node: TreeNode) -> List[str]: + """Compute position-aware SHA-256 hashes for *node* (one per page). + + Lazily computed and cached on ``node.hash_values``. + """ + if node.hash_values is not None: + return node.hash_values + + parent_hash: Optional[str] = None + if ( + node.parent is not None + and node.parent.hash_values is not None + and len(node.parent.key) > 0 + and len(node.parent.hash_values) > 0 + ): + parent_hash = node.parent.hash_values[-1] + + hashes: List[str] = [] + for start in range(0, len(node.key), self.page_size): + page = node.key.token_ids[start : start + self.page_size] + if not page: + continue + h = hash_token_ids(page, prior_hash=parent_hash) + hashes.append(h) + parent_hash = h + + node.hash_values = hashes + return hashes + + def pretty_print(self) -> None: + """Print the tree structure to stdout.""" + self._print_helper(self.root_node, 0) + print( + f"total={self.total_size()} evictable={self._evictable_size}" + + ( + f" swa_evictable={self._swa_evictable_size}" + if self.supports_swa + else "" + ) + ) + + def _match_normal(self, key: RadixKey) -> Tuple[List[torch.Tensor], TreeNode]: + node = self.root_node + now = time.monotonic() + node.last_access_time = now + values: List[torch.Tensor] = [] + + while len(key) > 0: + ck = _child_key(key, self.page_size) + if ck not in node.children: + break + child = node.children[ck] + child.last_access_time = now + child.hit_count += 1 + plen = _key_match(child.key, key, self.page_size) + if plen < len(child.key): + new_node = self._split_node(child.key, child, plen) + 
values.append(new_node.value) + node = new_node + break + values.append(child.value) + node = child + key = key[plen:] + + return values, node + + def _match_swa(self, key: RadixKey) -> Tuple[List[torch.Tensor], TreeNode, int]: + """SWA-aware match. Returns *(values, last_node, best_value_count)*. + + ``best_value_count`` is the number of value tensors from *values* + that form a valid SWA-safe prefix (enough non-tombstone tokens within + the sliding window, or a tombstone-free path to root). + """ + node = self.root_node + values: List[torch.Tensor] = [] + non_tomb_len: float = float("inf") + best_count = 0 + best_node = node + + while len(key) > 0: + ck = _child_key(key, self.page_size) + if ck not in node.children: + break + child = node.children[ck] + + if child.swa_tombstone: + if non_tomb_len >= self.sliding_window_size: + best_count = len(values) + best_node = node + non_tomb_len = 0 + + plen = _key_match(child.key, key, self.page_size) + if plen < len(child.key): + new_node = self._split_node(child.key, child, plen) + values.append(new_node.value) + if not new_node.swa_tombstone: + non_tomb_len += len(new_node.value) + node = new_node + break + values.append(child.value) + if not child.swa_tombstone: + non_tomb_len += len(child.value) + node = child + key = key[plen:] + + if non_tomb_len >= self.sliding_window_size: + best_count = len(values) + best_node = node + + return values, best_node, best_count + + def _insert_normal( + self, node: TreeNode, key: RadixKey, value: torch.Tensor + ) -> Tuple[int, TreeNode]: + """Insert into non-SWA tree. 
Returns ``(prefix_len, last_node)``.""" + now = time.monotonic() + node.last_access_time = now + if len(key) == 0: + return 0, node + + total_prefix = 0 + ck = _child_key(key, self.page_size) + while len(key) > 0 and ck in node.children: + node = node.children[ck] + node.last_access_time = now + plen = _key_match(node.key, key, self.page_size) + total_prefix += plen + key = key[plen:] + value = value[plen:] + + if plen < len(node.key): + # Partial match: split the node. ``node`` must advance to + # the NEW parent so that any remaining key is added as a + # sibling of the tail, not a child of it. + node = self._split_node(node.key, node, plen) + if len(key) > 0: + ck = _child_key(key, self.page_size) + + if len(key) > 0: + new_leaf = self._add_leaf(node, key, value) + node = new_leaf + + return total_prefix, node + + def _insert_swa( + self, + node: TreeNode, + key: RadixKey, + value: torch.Tensor, + prev_prefix_len: int, + swa_evicted_seqlen: int, + ) -> int: + """Insert with SWA tombstone awareness. + + When an existing node is tombstoned and the incoming *value* carries + fresh SWA KV (i.e. beyond *swa_evicted_seqlen*), the node is + un-tombstoned and its value is replaced. 
+ """ + now = time.monotonic() + node.last_access_time = now + if len(key) == 0: + return 0 + + total_prefix = 0 + while len(key) > 0: + ck = _child_key(key, self.page_size) + if ck not in node.children: + break + node = node.children[ck] + node.last_access_time = now + plen = _key_match(node.key, key, self.page_size) + + if plen < len(node.key): + self._split_node(node.key, node, plen) + + beyond_protected = prev_prefix_len < total_prefix + plen + if beyond_protected and node.swa_tombstone: + if swa_evicted_seqlen <= total_prefix: + self._free_indices(node.value[:plen]) + node.value = value[:plen].clone() + node.swa_tombstone = False + self._swa_evictable_size += len(node.value) + else: + self._free_indices(value[:plen]) + elif beyond_protected: + self._free_indices(value[:plen]) + + total_prefix += plen + key = key[plen:] + value = value[plen:] + + if len(key) > 0: + if ( + swa_evicted_seqlen > total_prefix + and swa_evicted_seqlen < total_prefix + len(key) + ): + tomb_len = swa_evicted_seqlen - total_prefix + self._add_leaf( + node, key[:tomb_len], value[:tomb_len], swa_tombstone=True + ) + node = node.children[_child_key(key, self.page_size)] + key = key[tomb_len:] + value = value[tomb_len:] + + if len(key) > 0: + self._add_leaf(node, key, value, swa_tombstone=False) + + return total_prefix + + def _add_leaf( + self, + parent: TreeNode, + key: RadixKey, + value: torch.Tensor, + swa_tombstone: bool = False, + ) -> TreeNode: + new_node = TreeNode() + new_node.parent = parent + new_node.key = key + new_node.value = value.clone() + new_node.swa_tombstone = swa_tombstone + parent.children[_child_key(key, self.page_size)] = new_node + self._evictable_size += len(key) + if self.supports_swa and not swa_tombstone: + self._swa_evictable_size += len(key) + return new_node + + def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode: + """Split *child* at *split_len*, returning the new parent node.""" + new_node = TreeNode() + 
new_node.children[_child_key(key[split_len:], self.page_size)] = child + new_node.parent = child.parent + new_node.lock_ref = child.lock_ref + new_node.swa_lock_ref = child.swa_lock_ref + new_node.swa_tombstone = child.swa_tombstone + new_node.swa_boundary_id = child.swa_boundary_id + child.swa_boundary_id = None + new_node.key = child.key[:split_len] + new_node.value = child.value[:split_len].clone() + + # Split hash values if they exist + if child.hash_values is not None: + pages = split_len // self.page_size if self.page_size > 1 else split_len + new_node.hash_values = child.hash_values[:pages] + child.hash_values = child.hash_values[pages:] + else: + new_node.hash_values = None + + child.parent = new_node + child.key = child.key[split_len:] + child.value = child.value[split_len:].clone() + new_node.parent.children[_child_key(key, self.page_size)] = new_node + return new_node + + def _delete_leaf(self, node: TreeNode) -> None: + ck = _child_key(node.key, self.page_size) + node.parent.children.pop(ck, None) + self._evictable_size -= len(node.key) + if self.supports_swa and not node.swa_tombstone: + self._swa_evictable_size -= len(node.key) + if self.on_node_evict is not None: + self.on_node_evict(node.id) + + def _tombstone_node(self, node: TreeNode) -> None: + node.swa_tombstone = True + self._swa_evictable_size -= len(node.key) + + def _collect_evictable_leaves(self) -> List[TreeNode]: + leaves: List[TreeNode] = [] + stack: List[TreeNode] = [self.root_node] + while stack: + n = stack.pop() + if n.evicted: + continue + has_live_child = False + for c in n.children.values(): + if not c.evicted: + has_live_child = True + stack.append(c) + if not has_live_child and n.lock_ref == 0 and n != self.root_node: + leaves.append(n) + return leaves + + def _collect_swa_evictable(self) -> List[TreeNode]: + nodes: List[TreeNode] = [] + stack: List[TreeNode] = [self.root_node] + while stack: + n = stack.pop() + if n.evicted: + continue + if n != self.root_node and not 
n.swa_tombstone and n.swa_lock_ref == 0: + nodes.append(n) + stack.extend(c for c in n.children.values() if not c.evicted) + return nodes + + def _page_align_key(self, key: RadixKey) -> RadixKey: + if self.page_size == 1: + return key + aligned = len(key) // self.page_size * self.page_size + return key[:aligned] + + def _free_indices(self, indices: torch.Tensor) -> None: + if self.pool is not None and len(indices) > 0: + self.pool.free(indices) + + def _free_swa_indices(self, indices: torch.Tensor) -> None: + if self.pool is not None and len(indices) > 0: + if hasattr(self.pool, "free_swa"): + self.pool.free_swa(indices) + else: + self.pool.free(indices) + + def _print_helper(self, node: TreeNode, indent: int) -> None: + stack = [(node, indent)] + while stack: + n, ind = stack.pop() + toks = n.key.token_ids[:10] if n.key else [] + klen = len(n.key) if n.key else 0 + flags = f"lock={n.lock_ref}" + if self.supports_swa: + flags += f" swa={n.swa_lock_ref} tomb={n.swa_tombstone}" + print(f"{' ' * ind}[{klen}] {toks} {flags}") + for c in n.children.values(): + stack.append((c, ind + 1)) diff --git a/pymllm/mobile/README.md b/pymllm/mobile/README.md index 29877ea00..ceb71a5d3 100644 --- a/pymllm/mobile/README.md +++ b/pymllm/mobile/README.md @@ -1 +1,2 @@ -We should refactor current pymllm's src to mobile directory. And provide more functionalities for torch based VLA. +# Pymllm mobile + diff --git a/pymllm/mobile/__init__.py b/pymllm/mobile/__init__.py new file mode 100644 index 000000000..8796bbeaf --- /dev/null +++ b/pymllm/mobile/__init__.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from . import ffi +from . import convertor +from . import utils +from . import quantize +from . import nn +from . import service +from . 
import backends +from .ffi import ( + # Floating point types + float32, + float16, + bfloat16, + # Signed integer types + int8, + int16, + int32, + int64, + # Unsigned integer types + uint8, + uint16, + uint32, + uint64, + # Bool type + boolean, + # Devices + cpu, + cuda, + qnn, + # Tensor and utilities + Tensor, + empty, + echo, + device, + is_torch_available, + is_numpy_available, + from_torch, + from_numpy, + zeros, + ones, + arange, + random, +) +from .nn.functional import matmul diff --git a/pymllm/quantize/spinquant/__init__.py b/pymllm/mobile/backends/__init__.py similarity index 71% rename from pymllm/quantize/spinquant/__init__.py rename to pymllm/mobile/backends/__init__.py index ea8e2bec7..1578a0d87 100644 --- a/pymllm/quantize/spinquant/__init__.py +++ b/pymllm/mobile/backends/__init__.py @@ -1,2 +1,4 @@ # Copyright (c) MLLM Team. # Licensed under the MIT License. + +from . import qualcomm diff --git a/pymllm/backends/qualcomm/README.md b/pymllm/mobile/backends/qualcomm/README.md similarity index 100% rename from pymllm/backends/qualcomm/README.md rename to pymllm/mobile/backends/qualcomm/README.md diff --git a/pymllm/backends/qualcomm/__init__.py b/pymllm/mobile/backends/qualcomm/__init__.py similarity index 100% rename from pymllm/backends/qualcomm/__init__.py rename to pymllm/mobile/backends/qualcomm/__init__.py diff --git a/pymllm/backends/qualcomm/nn.py b/pymllm/mobile/backends/qualcomm/nn.py similarity index 75% rename from pymllm/backends/qualcomm/nn.py rename to pymllm/mobile/backends/qualcomm/nn.py index 0ba9aef55..e4bc91ace 100644 --- a/pymllm/backends/qualcomm/nn.py +++ b/pymllm/mobile/backends/qualcomm/nn.py @@ -1,4 +1,4 @@ -from pymllm.nn._layers import Softmax, RoPE +from pymllm.mobile.nn._layers import Softmax, RoPE class QnnSoftmax(Softmax): diff --git a/pymllm/backends/qualcomm/qnn_aot_env.py b/pymllm/mobile/backends/qualcomm/qnn_aot_env.py similarity index 83% rename from pymllm/backends/qualcomm/qnn_aot_env.py rename to 
pymllm/mobile/backends/qualcomm/qnn_aot_env.py index 8b0c0d2e1..bc48c7c97 100644 --- a/pymllm/backends/qualcomm/qnn_aot_env.py +++ b/pymllm/mobile/backends/qualcomm/qnn_aot_env.py @@ -1,7 +1,7 @@ -from pymllm.ffi import is_qnn_aot_on_x86_enabled +from pymllm.mobile.ffi import is_qnn_aot_on_x86_enabled if is_qnn_aot_on_x86_enabled(): - from pymllm.ffi import ( + from pymllm.mobile.ffi import ( QnnDeviceAndContext, QnnAOTEnv, QcomChipset, diff --git a/pymllm/backends/qualcomm/transformers/.gitignore b/pymllm/mobile/backends/qualcomm/transformers/.gitignore similarity index 100% rename from pymllm/backends/qualcomm/transformers/.gitignore rename to pymllm/mobile/backends/qualcomm/transformers/.gitignore diff --git a/pymllm/backends/qualcomm/transformers/README.md b/pymllm/mobile/backends/qualcomm/transformers/README.md similarity index 100% rename from pymllm/backends/qualcomm/transformers/README.md rename to pymllm/mobile/backends/qualcomm/transformers/README.md diff --git a/pymllm/backends/qualcomm/transformers/__init__.py b/pymllm/mobile/backends/qualcomm/transformers/__init__.py similarity index 100% rename from pymllm/backends/qualcomm/transformers/__init__.py rename to pymllm/mobile/backends/qualcomm/transformers/__init__.py diff --git a/pymllm/compile/mlir/__init__.py b/pymllm/mobile/backends/qualcomm/transformers/core/__init__.py similarity index 100% rename from pymllm/compile/mlir/__init__.py rename to pymllm/mobile/backends/qualcomm/transformers/core/__init__.py diff --git a/pymllm/backends/qualcomm/transformers/core/embedding.py b/pymllm/mobile/backends/qualcomm/transformers/core/embedding.py similarity index 100% rename from pymllm/backends/qualcomm/transformers/core/embedding.py rename to pymllm/mobile/backends/qualcomm/transformers/core/embedding.py diff --git a/pymllm/backends/qualcomm/transformers/core/observer.py b/pymllm/mobile/backends/qualcomm/transformers/core/observer.py similarity index 100% rename from 
pymllm/backends/qualcomm/transformers/core/observer.py rename to pymllm/mobile/backends/qualcomm/transformers/core/observer.py diff --git a/pymllm/backends/qualcomm/transformers/core/qdq.py b/pymllm/mobile/backends/qualcomm/transformers/core/qdq.py similarity index 100% rename from pymllm/backends/qualcomm/transformers/core/qdq.py rename to pymllm/mobile/backends/qualcomm/transformers/core/qdq.py diff --git a/pymllm/backends/qualcomm/transformers/core/qlinear.py b/pymllm/mobile/backends/qualcomm/transformers/core/qlinear.py similarity index 99% rename from pymllm/backends/qualcomm/transformers/core/qlinear.py rename to pymllm/mobile/backends/qualcomm/transformers/core/qlinear.py index 9e90ba8a5..35439180c 100644 --- a/pymllm/backends/qualcomm/transformers/core/qlinear.py +++ b/pymllm/mobile/backends/qualcomm/transformers/core/qlinear.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.ao.quantization import FakeQuantize, PerChannelMinMaxObserver -from pymllm.backends.qualcomm.transformers.core.observer import ( +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ( PerBlockParamFakeQuantize, ) from torchao.quantization.quant_primitives import ( diff --git a/pymllm/backends/qualcomm/transformers/core/rms_norm.py b/pymllm/mobile/backends/qualcomm/transformers/core/rms_norm.py similarity index 100% rename from pymllm/backends/qualcomm/transformers/core/rms_norm.py rename to pymllm/mobile/backends/qualcomm/transformers/core/rms_norm.py diff --git a/pymllm/backends/qualcomm/transformers/llama/modeling_llama.py b/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py similarity index 98% rename from pymllm/backends/qualcomm/transformers/llama/modeling_llama.py rename to pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py index 119ec04bc..6b65f34b9 100644 --- a/pymllm/backends/qualcomm/transformers/llama/modeling_llama.py +++ 
b/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py @@ -52,16 +52,16 @@ from transformers.models.llama.configuration_llama import LlamaConfig # Replace linear, rms_norm with: -from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm -from pymllm.backends.qualcomm.transformers.core.qlinear import ( +from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, ) -from pymllm.backends.qualcomm.transformers.core.qdq import ( +from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( ActivationQDQ, FixedActivationQDQ, ) -from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver +from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver logger = logging.get_logger(__name__) diff --git a/pymllm/backends/qualcomm/transformers/llama/runner.py b/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py similarity index 96% rename from pymllm/backends/qualcomm/transformers/llama/runner.py rename to pymllm/mobile/backends/qualcomm/transformers/llama/runner.py index 8aa4627bf..730147d0f 100644 --- a/pymllm/backends/qualcomm/transformers/llama/runner.py +++ b/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py @@ -2,18 +2,18 @@ from tqdm import tqdm from modelscope.msdatasets import MsDataset from transformers import AutoTokenizer -from pymllm.backends.qualcomm.transformers.core.qdq import ( +from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( ActivationQDQ, FixedActivationQDQ, ) -from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm -from pymllm.backends.qualcomm.transformers.core.qlinear import ( +from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm 
+from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, QLinearW8A16_PerChannelSym, ) -from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.backends.qualcomm.transformers.llama.modeling_llama import LlamaForCausalLM -from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver +from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.mobile.backends.qualcomm.transformers.llama.modeling_llama import LlamaForCausalLM +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver def recompute_scale_zp(module): diff --git a/pymllm/backends/qualcomm/transformers/llama/train.py b/pymllm/mobile/backends/qualcomm/transformers/llama/train.py similarity index 94% rename from pymllm/backends/qualcomm/transformers/llama/train.py rename to pymllm/mobile/backends/qualcomm/transformers/llama/train.py index cd10befba..41ffc0e27 100644 --- a/pymllm/backends/qualcomm/transformers/llama/train.py +++ b/pymllm/mobile/backends/qualcomm/transformers/llama/train.py @@ -2,7 +2,7 @@ import torch import argparse from safetensors.torch import save_model -from pymllm.backends.qualcomm.transformers.llama.runner import LlamaQuantizer +from pymllm.mobile.backends.qualcomm.transformers.llama.runner import LlamaQuantizer def main(): diff --git a/pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py b/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py similarity index 98% rename from pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py rename to pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py index 56b19c421..a43d8b7ea 100644 --- a/pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py @@ -31,16 +31,16 @@ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config # Replace linear, rms_norm with: -from 
pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm -from pymllm.backends.qualcomm.transformers.core.qlinear import ( +from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, ) -from pymllm.backends.qualcomm.transformers.core.qdq import ( +from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( ActivationQDQ, FixedActivationQDQ, ) -from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver +from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver class Qwen2MLP(nn.Module): diff --git a/pymllm/backends/qualcomm/transformers/qwen2/runner.py b/pymllm/mobile/backends/qualcomm/transformers/qwen2/runner.py similarity index 96% rename from pymllm/backends/qualcomm/transformers/qwen2/runner.py rename to pymllm/mobile/backends/qualcomm/transformers/qwen2/runner.py index d2f5be05b..ce55fd06d 100644 --- a/pymllm/backends/qualcomm/transformers/qwen2/runner.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen2/runner.py @@ -2,18 +2,18 @@ from tqdm import tqdm from modelscope.msdatasets import MsDataset from transformers import AutoTokenizer -from pymllm.backends.qualcomm.transformers.core.qdq import ( +from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( ActivationQDQ, FixedActivationQDQ, ) -from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm -from pymllm.backends.qualcomm.transformers.core.qlinear import ( +from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, QLinearW8A16_PerChannelSym, ) -from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding -from 
pymllm.backends.qualcomm.transformers.qwen2.modeling_qwen2 import Qwen2ForCausalLM -from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver +from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.mobile.backends.qualcomm.transformers.qwen2.modeling_qwen2 import Qwen2ForCausalLM +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver def recompute_scale_zp(module): diff --git a/pymllm/backends/qualcomm/transformers/qwen2/train.py b/pymllm/mobile/backends/qualcomm/transformers/qwen2/train.py similarity index 94% rename from pymllm/backends/qualcomm/transformers/qwen2/train.py rename to pymllm/mobile/backends/qualcomm/transformers/qwen2/train.py index fec5fdfca..1a8f25ce9 100644 --- a/pymllm/backends/qualcomm/transformers/qwen2/train.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen2/train.py @@ -2,7 +2,7 @@ import torch import argparse from safetensors.torch import save_model -from pymllm.backends.qualcomm.transformers.qwen2.runner import Qwen2Quantizer +from pymllm.mobile.backends.qualcomm.transformers.qwen2.runner import Qwen2Quantizer def main(): diff --git a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/mobile/backends/qualcomm/transformers/qwen3/modeling_qwen3.py similarity index 98% rename from pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py rename to pymllm/mobile/backends/qualcomm/transformers/qwen3/modeling_qwen3.py index 2dabf5c9c..6a8788bad 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen3/modeling_qwen3.py @@ -46,16 +46,16 @@ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config # Replace linear, rms_norm with: -from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm -from pymllm.backends.qualcomm.transformers.core.qlinear import ( +from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import 
QRMSNorm +from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, ) -from pymllm.backends.qualcomm.transformers.core.qdq import ( +from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( ActivationQDQ, FixedActivationQDQ, ) -from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver +from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver class Qwen3MLP(nn.Module): diff --git a/pymllm/backends/qualcomm/transformers/qwen3/runner.py b/pymllm/mobile/backends/qualcomm/transformers/qwen3/runner.py similarity index 96% rename from pymllm/backends/qualcomm/transformers/qwen3/runner.py rename to pymllm/mobile/backends/qualcomm/transformers/qwen3/runner.py index 02ea6a5f0..0d7499c96 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/runner.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen3/runner.py @@ -2,18 +2,18 @@ from tqdm import tqdm from modelscope.msdatasets import MsDataset from transformers import AutoTokenizer -from pymllm.backends.qualcomm.transformers.core.qdq import ( +from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( ActivationQDQ, FixedActivationQDQ, ) -from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm -from pymllm.backends.qualcomm.transformers.core.qlinear import ( +from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, QLinearW8A16_PerChannelSym, ) -from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.backends.qualcomm.transformers.qwen3.modeling_qwen3 import Qwen3ForCausalLM -from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver +from 
pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.mobile.backends.qualcomm.transformers.qwen3.modeling_qwen3 import Qwen3ForCausalLM +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver def recompute_scale_zp(module): diff --git a/pymllm/backends/qualcomm/transformers/qwen3/train.py b/pymllm/mobile/backends/qualcomm/transformers/qwen3/train.py similarity index 94% rename from pymllm/backends/qualcomm/transformers/qwen3/train.py rename to pymllm/mobile/backends/qualcomm/transformers/qwen3/train.py index 63c6d0e86..f44fa67b5 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/train.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen3/train.py @@ -2,7 +2,7 @@ import torch import argparse from safetensors.torch import save_model -from pymllm.backends.qualcomm.transformers.qwen3.runner import Qwen3Quantizer +from pymllm.mobile.backends.qualcomm.transformers.qwen3.runner import Qwen3Quantizer def main(): diff --git a/pymllm/convertor/__init__.py b/pymllm/mobile/convertor/__init__.py similarity index 100% rename from pymllm/convertor/__init__.py rename to pymllm/mobile/convertor/__init__.py diff --git a/pymllm/convertor/mllm_type_mapping.py b/pymllm/mobile/convertor/mllm_type_mapping.py similarity index 100% rename from pymllm/convertor/mllm_type_mapping.py rename to pymllm/mobile/convertor/mllm_type_mapping.py diff --git a/pymllm/convertor/model_file_v1.py b/pymllm/mobile/convertor/model_file_v1.py similarity index 100% rename from pymllm/convertor/model_file_v1.py rename to pymllm/mobile/convertor/model_file_v1.py diff --git a/pymllm/convertor/model_file_v2.py b/pymllm/mobile/convertor/model_file_v2.py similarity index 100% rename from pymllm/convertor/model_file_v2.py rename to pymllm/mobile/convertor/model_file_v2.py diff --git a/pymllm/ffi/__init__.py b/pymllm/mobile/ffi/__init__.py similarity index 100% rename from pymllm/ffi/__init__.py rename to pymllm/mobile/ffi/__init__.py 
diff --git a/pymllm/ffi/_ffi_api.py b/pymllm/mobile/ffi/_ffi_api.py similarity index 100% rename from pymllm/ffi/_ffi_api.py rename to pymllm/mobile/ffi/_ffi_api.py diff --git a/pymllm/ffi/base.py b/pymllm/mobile/ffi/base.py similarity index 90% rename from pymllm/ffi/base.py rename to pymllm/mobile/ffi/base.py index 07a01c49e..96aed2425 100644 --- a/pymllm/ffi/base.py +++ b/pymllm/mobile/ffi/base.py @@ -8,7 +8,7 @@ def _load_lib(): file_dir = os.path.dirname(os.path.realpath(__file__)) - parent_dir = os.path.dirname(file_dir) + parent_dir = os.path.dirname(os.path.dirname(file_dir)) # Platform-specific library names if sys.platform.startswith("win32"): diff --git a/pymllm/nn/__init__.py b/pymllm/mobile/nn/__init__.py similarity index 100% rename from pymllm/nn/__init__.py rename to pymllm/mobile/nn/__init__.py diff --git a/pymllm/nn/_layers.py b/pymllm/mobile/nn/_layers.py similarity index 100% rename from pymllm/nn/_layers.py rename to pymllm/mobile/nn/_layers.py diff --git a/pymllm/nn/_module.py b/pymllm/mobile/nn/_module.py similarity index 100% rename from pymllm/nn/_module.py rename to pymllm/mobile/nn/_module.py diff --git a/pymllm/nn/functional.py b/pymllm/mobile/nn/functional.py similarity index 100% rename from pymllm/nn/functional.py rename to pymllm/mobile/nn/functional.py diff --git a/pymllm/quantize/__init__.py b/pymllm/mobile/quantize/__init__.py similarity index 100% rename from pymllm/quantize/__init__.py rename to pymllm/mobile/quantize/__init__.py diff --git a/pymllm/quantize/cast2fp32_pass.py b/pymllm/mobile/quantize/cast2fp32_pass.py similarity index 100% rename from pymllm/quantize/cast2fp32_pass.py rename to pymllm/mobile/quantize/cast2fp32_pass.py diff --git a/pymllm/compile/__init__.py b/pymllm/mobile/quantize/gguf/__init__.py similarity index 100% rename from pymllm/compile/__init__.py rename to pymllm/mobile/quantize/gguf/__init__.py diff --git a/pymllm/quantize/kai/__init__.py b/pymllm/mobile/quantize/kai/__init__.py similarity index 
100% rename from pymllm/quantize/kai/__init__.py rename to pymllm/mobile/quantize/kai/__init__.py diff --git a/pymllm/quantize/kai/w4a32.py b/pymllm/mobile/quantize/kai/w4a32.py similarity index 100% rename from pymllm/quantize/kai/w4a32.py rename to pymllm/mobile/quantize/kai/w4a32.py diff --git a/pymllm/quantize/pipeline.py b/pymllm/mobile/quantize/pipeline.py similarity index 100% rename from pymllm/quantize/pipeline.py rename to pymllm/mobile/quantize/pipeline.py diff --git a/pymllm/quantize/quantize_pass.py b/pymllm/mobile/quantize/quantize_pass.py similarity index 100% rename from pymllm/quantize/quantize_pass.py rename to pymllm/mobile/quantize/quantize_pass.py diff --git a/pymllm/quantize/solver.py b/pymllm/mobile/quantize/solver.py similarity index 100% rename from pymllm/quantize/solver.py rename to pymllm/mobile/quantize/solver.py diff --git a/pymllm/quantize/gguf/__init__.py b/pymllm/mobile/quantize/spinquant/__init__.py similarity index 100% rename from pymllm/quantize/gguf/__init__.py rename to pymllm/mobile/quantize/spinquant/__init__.py diff --git a/pymllm/service/__init__.py b/pymllm/mobile/service/__init__.py similarity index 100% rename from pymllm/service/__init__.py rename to pymllm/mobile/service/__init__.py diff --git a/pymllm/service/models_hub.py b/pymllm/mobile/service/models_hub.py similarity index 100% rename from pymllm/service/models_hub.py rename to pymllm/mobile/service/models_hub.py diff --git a/pymllm/service/network.py b/pymllm/mobile/service/network.py similarity index 100% rename from pymllm/service/network.py rename to pymllm/mobile/service/network.py diff --git a/pymllm/service/rr_process.py b/pymllm/mobile/service/rr_process.py similarity index 100% rename from pymllm/service/rr_process.py rename to pymllm/mobile/service/rr_process.py diff --git a/pymllm/service/tools.py b/pymllm/mobile/service/tools.py similarity index 100% rename from pymllm/service/tools.py rename to pymllm/mobile/service/tools.py diff --git 
a/pymllm/tests/qualcomm/test_context_create.py b/pymllm/mobile/tests/qualcomm/test_context_create.py similarity index 89% rename from pymllm/tests/qualcomm/test_context_create.py rename to pymllm/mobile/tests/qualcomm/test_context_create.py index 18983daa7..94f42b513 100644 --- a/pymllm/tests/qualcomm/test_context_create.py +++ b/pymllm/mobile/tests/qualcomm/test_context_create.py @@ -1,5 +1,5 @@ -import pymllm as mllm -from pymllm.backends.qualcomm.qnn_aot_env import ( +import pymllm.mobile as mllm +from pymllm.mobile.backends.qualcomm.qnn_aot_env import ( QnnAOTEnv, QnnDeviceAndContext, QcomTryBestPerformance, diff --git a/pymllm/tests/test_nn.py b/pymllm/mobile/tests/test_nn.py similarity index 83% rename from pymllm/tests/test_nn.py rename to pymllm/mobile/tests/test_nn.py index d9a3db2d8..403060e99 100644 --- a/pymllm/tests/test_nn.py +++ b/pymllm/mobile/tests/test_nn.py @@ -1,5 +1,5 @@ -import pymllm as mllm -from pymllm import nn +import pymllm.mobile as mllm +from pymllm.mobile import nn class FooModule(nn.Module): diff --git a/pymllm/tests/test_tensor.py b/pymllm/mobile/tests/test_tensor.py similarity index 89% rename from pymllm/tests/test_tensor.py rename to pymllm/mobile/tests/test_tensor.py index e935f10b4..474e10922 100644 --- a/pymllm/tests/test_tensor.py +++ b/pymllm/mobile/tests/test_tensor.py @@ -1,7 +1,7 @@ # Copyright (c) MLLM Team. # Licensed under the MIT License. 
-import pymllm as torch +import pymllm.mobile as torch def test_empty_tensor_create() -> bool: diff --git a/pymllm/utils/__init__.py b/pymllm/mobile/utils/__init__.py similarity index 100% rename from pymllm/utils/__init__.py rename to pymllm/mobile/utils/__init__.py diff --git a/pymllm/utils/adb.py b/pymllm/mobile/utils/adb.py similarity index 100% rename from pymllm/utils/adb.py rename to pymllm/mobile/utils/adb.py diff --git a/pymllm/utils/error_handler.py b/pymllm/mobile/utils/error_handler.py similarity index 100% rename from pymllm/utils/error_handler.py rename to pymllm/mobile/utils/error_handler.py diff --git a/pymllm/utils/mllm_convertor.py b/pymllm/mobile/utils/mllm_convertor.py similarity index 100% rename from pymllm/utils/mllm_convertor.py rename to pymllm/mobile/utils/mllm_convertor.py diff --git a/pymllm/models/__init__.py b/pymllm/models/__init__.py new file mode 100644 index 000000000..7751b3091 --- /dev/null +++ b/pymllm/models/__init__.py @@ -0,0 +1,62 @@ +"""Model registry for pymllm. + +Maps HuggingFace ``config.architectures[0]`` strings to pymllm model classes. +Models are imported lazily via ``importlib`` so that heavy dependencies (torch, +numpy, etc.) are only loaded when a model is actually requested. 
+""" + +from __future__ import annotations + +import importlib +import logging +from typing import Dict, Optional, Tuple, Type + +import torch.nn as nn + +logger = logging.getLogger(__name__) + +# (module_path, class_name) +_MODEL_REGISTRY: Dict[str, Tuple[str, str]] = { + "Qwen3VLForConditionalGeneration": ( + "pymllm.models.qwen3_vl", + "Qwen3VLForConditionalGeneration", + ), + # Qwen3.5 (hybrid attention: full + GDN linear) + "Qwen3_5ForCausalLM": ( + "pymllm.models.qwen3_5", + "Qwen3_5ForCausalLM", + ), + "Qwen3_5ForConditionalGeneration": ( + "pymllm.models.qwen3_5", + "Qwen3_5ForConditionalGeneration", + ), +} + + +def get_model_class(architecture: str) -> Optional[Type[nn.Module]]: + """Look up a pymllm model class by HuggingFace architecture string. + + Returns ``None`` if the architecture is not registered or cannot be + imported. The caller is responsible for raising an appropriate error. + """ + entry = _MODEL_REGISTRY.get(architecture) + if entry is None: + return None + + module_path, class_name = entry + try: + module = importlib.import_module(module_path) + cls = getattr(module, class_name) + logger.info( + "Resolved architecture %r -> %s.%s", architecture, module_path, class_name + ) + return cls + except (ImportError, AttributeError) as exc: + logger.warning( + "Failed to import %s.%s for architecture %r: %s", + module_path, + class_name, + architecture, + exc, + ) + return None diff --git a/pymllm/models/qwen3_5.py b/pymllm/models/qwen3_5.py new file mode 100644 index 000000000..ca4dbe2ea --- /dev/null +++ b/pymllm/models/qwen3_5.py @@ -0,0 +1,530 @@ +"""Inference-only Qwen3.5 model for pymllm. + +Implements the hybrid attention architecture: +- **Full attention layers** (standard transformer with RoPE + output gate) +- **GDN linear attention layers** (Gated Delta Network, O(n) complexity) + +Layers alternate: linear, attention, linear, attention, ... based on +``full_attention_interval`` in the config. 
+ +Supports: +- Dense (non-MoE) variant +- Vision-Language (multimodal) via inheritance from Qwen3VL + +Adapted from sglang's ``qwen3_5.py``. +""" + +from __future__ import annotations + +import logging +import math +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pymllm.layers.attention.radix_attention import RadixAttention +from pymllm.layers.embedding import VocabParallelEmbedding +from pymllm.layers.gated_delta_net import GatedDeltaNet +from pymllm.layers.linear import Linear +from pymllm.layers.mlp import MLP +from pymllm.layers.rms_norm import GemmaRMSNorm, RMSNorm +from pymllm.layers.rope import apply_rope_pos_ids +from pymllm.layers.utils import set_weight_attrs + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Config helpers +# --------------------------------------------------------------------------- + + +def _get_text_config(config): + """Extract the text sub-config from a multimodal config, or return as-is.""" + return getattr(config, "text_config", config) + + +def _get_layer_types(config) -> List[str]: + """Return per-layer type list: 'attention' or 'linear_attention'.""" + if hasattr(config, "layers_block_type"): + return config.layers_block_type + # Compute from full_attention_interval + interval = getattr(config, "full_attention_interval", 2) + n_layers = config.num_hidden_layers + types = [] + for i in range(n_layers): + if (i + 1) % interval == 0: + types.append("attention") + else: + types.append("linear_attention") + return types + + +# --------------------------------------------------------------------------- +# Full Attention Layer (with output gate + QK norm) +# --------------------------------------------------------------------------- + + +class Qwen3_5FullAttention(nn.Module): + """Standard multi-head attention with RoPE, QK-norm, and optional output gate.""" + + 
class Qwen3_5FullAttention(nn.Module):
    """Multi-head attention with RoPE, per-head QK-norm, and an optional
    sigmoid output gate.

    When ``attn_output_gate`` is set, the Q projection is doubled and the
    second half of each head is used as a sigmoid gate on the attention
    output (Qwen3.5 convention).
    """

    def __init__(self, config, layer_id: int):
        super().__init__()
        tc = _get_text_config(config)
        self.hidden_size = tc.hidden_size
        self.num_heads = tc.num_attention_heads
        self.num_kv_heads = tc.num_key_value_heads
        self.head_dim = getattr(tc, "head_dim", self.hidden_size // self.num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim ** -0.5
        self.layer_id = layer_id

        # Output gate: when enabled, q_proj emits [Q | gate] per head.
        self.attn_output_gate = getattr(tc, "attn_output_gate", True)
        q_proj_size = self.q_size * 2 if self.attn_output_gate else self.q_size

        self.q_proj = Linear(self.hidden_size, q_proj_size, bias=False)
        self.k_proj = Linear(self.hidden_size, self.kv_size, bias=False)
        self.v_proj = Linear(self.hidden_size, self.kv_size, bias=False)
        self.o_proj = Linear(self.q_size, self.hidden_size, bias=False)

        # Per-head QK normalization.
        self.q_norm = GemmaRMSNorm(self.head_dim, eps=tc.rms_norm_eps)
        self.k_norm = GemmaRMSNorm(self.head_dim, eps=tc.rms_norm_eps)

        # RoPE configuration; theta may live in rope_parameters/rope_scaling
        # or directly on the config.
        self.partial_rotary_factor = getattr(tc, "partial_rotary_factor", 1.0)
        rope_config = (
            getattr(tc, "rope_parameters", None)
            or getattr(tc, "rope_scaling", None)
            or {}
        )
        self.rope_theta = rope_config.get(
            "rope_theta", getattr(tc, "rope_theta", 10000.0)
        )
        self.rotary_dim = int(self.head_dim * self.partial_rotary_factor)

        # RadixAttention delegates to the pluggable attention backend.
        self.attn = RadixAttention(
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            scaling=self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: Any,
    ) -> torch.Tensor:
        n_tok = hidden_states.shape[0]

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        gate = None
        if self.attn_output_gate:
            # q_proj emitted [Q | gate] per head; split them apart.
            qg = q.view(n_tok, self.num_heads, self.head_dim * 2)
            q, gate = qg.chunk(2, dim=-1)
            q = q.reshape(n_tok, -1)
            gate = gate.reshape(n_tok, -1)

        # Per-head QK norm operates on [*, head_dim] views.
        q = self.q_norm(q.reshape(-1, self.head_dim)).view(n_tok, -1)
        k = self.k_norm(k.reshape(-1, self.head_dim)).view(n_tok, -1)

        # RoPE in place; rotary_dim handles partial rotation.
        q = q.view(n_tok, self.num_heads, self.head_dim)
        k = k.view(n_tok, self.num_kv_heads, self.head_dim)
        apply_rope_pos_ids(
            q, k, positions, inplace=True,
            rotary_dim=self.rotary_dim, rope_theta=self.rope_theta,
        )
        q = q.reshape(n_tok, -1)
        k = k.reshape(n_tok, -1)

        out = self.attn(q, k, v, forward_batch)
        if self.attn_output_gate:
            out = out * torch.sigmoid(gate)
        return self.o_proj(out)


class Qwen3_5AttentionDecoderLayer(nn.Module):
    """Decoder layer: pre-norm full attention followed by pre-norm MLP."""

    def __init__(self, config, layer_id: int):
        super().__init__()
        tc = _get_text_config(config)
        self.self_attn = Qwen3_5FullAttention(config, layer_id)
        self.mlp = MLP(
            hidden_size=tc.hidden_size,
            intermediate_size=tc.intermediate_size,
            activation=tc.hidden_act,
        )
        self.input_layernorm = GemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        forward_batch: Any,
    ):
        # First layer in the stack: seed the residual stream.
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.self_attn(positions, hidden_states, forward_batch)

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class Qwen3_5LinearDecoderLayer(nn.Module):
    """Decoder layer: pre-norm GDN linear attention followed by pre-norm MLP."""

    def __init__(self, config, layer_id: int, gdn_layer_idx: int = 0):
        super().__init__()
        tc = _get_text_config(config)
        self.linear_attn = GatedDeltaNet(
            hidden_size=tc.hidden_size,
            num_k_heads=getattr(tc, "linear_num_key_heads", 16),
            num_v_heads=getattr(tc, "linear_num_value_heads", 32),
            head_k_dim=getattr(tc, "linear_key_head_dim", 128),
            head_v_dim=getattr(tc, "linear_value_head_dim", 128),
            conv_kernel_size=getattr(tc, "linear_conv_kernel_dim", 4),
            layer_id=layer_id,
            gdn_layer_idx=gdn_layer_idx,
            rms_norm_eps=tc.rms_norm_eps,
        )
        self.mlp = MLP(
            hidden_size=tc.hidden_size,
            intermediate_size=tc.intermediate_size,
            activation=tc.hidden_act,
        )
        self.input_layernorm = GemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        forward_batch: Any,
    ):
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Linear attention is position-free; positions are unused here.
        hidden_states = self.linear_attn(hidden_states, forward_batch)

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
# Maps config layer-type strings to decoder layer classes.
_DECODER_LAYER_TYPES = {
    "attention": Qwen3_5AttentionDecoderLayer,
    "linear_attention": Qwen3_5LinearDecoderLayer,
}


class Qwen3_5ForCausalLM(nn.Module):
    """Qwen3.5 causal language model with hybrid attention.

    Alternates between full attention and GDN linear attention layers.
    Dense (non-MoE) variant.
    """

    def __init__(self, config):
        super().__init__()
        tc = _get_text_config(config)
        self.config = tc
        self.hidden_size = tc.hidden_size
        self.vocab_size = tc.vocab_size

        self.embed_tokens = VocabParallelEmbedding(tc.vocab_size, tc.hidden_size)

        # Hybrid decoder stack. GDN layers get a dense 0..k index of their
        # own (gdn_layer_idx) independent of the global layer_id.
        self.layer_types = _get_layer_types(tc)
        self.layers = nn.ModuleList()
        self.full_attn_layer_ids = set()
        gdn_count = 0
        for idx, layer_type in enumerate(self.layer_types[: tc.num_hidden_layers]):
            if layer_type == "linear_attention":
                self.layers.append(
                    Qwen3_5LinearDecoderLayer(config, idx, gdn_layer_idx=gdn_count)
                )
                gdn_count += 1
            else:
                self.layers.append(Qwen3_5AttentionDecoderLayer(config, idx))
                self.full_attn_layer_ids.add(idx)
        self.num_gdn_layers = gdn_count

        self.norm = GemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps)

        logger.info(
            "Qwen3_5ForCausalLM: %d layers (%d attention + %d GDN)",
            tc.num_hidden_layers,
            len(self.full_attn_layer_ids),
            self.num_gdn_layers,
        )

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: Any,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = (
            self.embed_tokens(input_ids) if input_embeds is None else input_embeds
        )

        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                forward_batch=forward_batch,
            )

        # Final normalization (fused with residual add when one is pending).
        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HuggingFace checkpoint weights with name remapping.

        Skips rotary caches, MTP heads, and vision weights; fuses
        gate_proj/up_proj shards into the single ``gate_up_proj`` parameter.
        Returns the set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (fused param, checkpoint shard, shard index)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded: Set[str] = set()

        for name, weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "mtp" in name:
                continue
            if "visual" in name:
                continue
            if "language_model" in name:
                name = name.replace("model.language_model.", "")
            if name.startswith("model."):
                name = name[len("model."):]
            # NOTE: do NOT strip .self_attn — pymllm keeps it as a submodule

            # Stacked params: copy each shard into its half of gate_up_proj.
            matched = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                if name not in params_dict:
                    continue
                param = params_dict[name]
                # gate_up_proj is a plain Linear — manually place each shard
                half = param.shape[0] // 2
                param.data[shard_id * half : (shard_id + 1) * half].copy_(weight)
                matched = True
                break

            if not matched:
                if name not in params_dict:
                    continue
                param = params_dict[name]
                loader = getattr(param, "weight_loader", None)
                if loader is not None:
                    loader(param, weight)
                else:
                    # Squeeze conv1d weight from [C, 1, K] to [C, K].
                    if weight.dim() != param.dim():
                        weight = weight.squeeze()
                    param.data.copy_(weight)

            loaded.add(name)

        logger.info("Loaded %d parameter tensors for Qwen3_5ForCausalLM", len(loaded))
        return loaded
class Qwen3_5ForConditionalGeneration(nn.Module):
    """Qwen3.5 multimodal model (text + vision).

    Reuses Qwen3VL's vision encoder and pairs it with Qwen3.5's hybrid
    (full + GDN linear attention) language model.
    """

    def __init__(self, config):
        super().__init__()
        # Local import avoids a cycle with pymllm.models.qwen3_vl.
        from pymllm.models.qwen3_vl import (
            Qwen3VLVisionModel,
        )

        self.config = config
        tc = _get_text_config(config)

        # Vision encoder (reuse Qwen3VL's vision model); absent for
        # text-only configs.
        vision_config = getattr(config, "vision_config", None)
        if vision_config is None:
            self.visual = None
        else:
            self.visual = Qwen3VLVisionModel(
                depth=getattr(vision_config, "depth", 27),
                hidden_size=getattr(vision_config, "hidden_size", 1152),
                hidden_act=getattr(vision_config, "hidden_act", "gelu_pytorch_tanh"),
                intermediate_size=getattr(vision_config, "intermediate_size", 4304),
                num_heads=getattr(vision_config, "num_heads", 16),
                in_channels=getattr(vision_config, "in_channels", 3),
                patch_size=getattr(vision_config, "patch_size", 16),
                spatial_merge_size=getattr(vision_config, "spatial_merge_size", 2),
                temporal_patch_size=getattr(vision_config, "temporal_patch_size", 2),
                out_hidden_size=getattr(vision_config, "out_hidden_size", 3584),
                num_position_embeddings=getattr(
                    vision_config, "num_position_embeddings", 2304
                ),
                deepstack_visual_indexes=getattr(
                    vision_config, "deepstack_visual_indexes", [8, 16, 24]
                ),
                norm_eps=getattr(tc, "rms_norm_eps", 1e-6),
            )

        # Language model.
        self.model = Qwen3_5ForCausalLM(config)

        # Expose hybrid model metadata for ModelRunner.
        self.num_gdn_layers = self.model.num_gdn_layers
        self.full_attn_layer_ids = self.model.full_attn_layer_ids

        # LM head (tied to embedding when tie_word_embeddings=True).
        self.lm_head = Linear(tc.hidden_size, tc.vocab_size, bias=False)
        if getattr(tc, "tie_word_embeddings", False):
            self.lm_head.weight = self.model.embed_tokens.weight

        # Vision placeholder token IDs.
        self.image_token_id = getattr(config, "image_token_id", 151655)
        self.video_token_id = getattr(config, "video_token_id", 151656)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: Any,
        input_embeds: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Splice visual embeddings into the token stream when pixels arrive.
        if input_embeds is None and pixel_values is not None and self.visual is not None:
            input_embeds = self.model.embed_tokens(input_ids)
            visual_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            # Replace image/video placeholder tokens with visual embeddings.
            mask = (input_ids == self.image_token_id) | (
                input_ids == self.video_token_id
            )
            if mask.any():
                input_embeds[mask] = visual_embeds.reshape(
                    -1, visual_embeds.shape[-1]
                )

        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=input_embeds,
        )

        return self.lm_head(hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights, dispatching visual vs language params."""
        visual_weights = []
        language_weights = []

        for name, weight in weights:
            if "visual" in name or "model.visual" in name:
                # Normalize visual weight names to this module's layout.
                name = name.replace("model.visual.", "visual.")
                name = name.replace("attn.qkv.", "attn.qkv_proj.")
                visual_weights.append((name, weight))
            else:
                language_weights.append((name, weight))

        # Language model handles its own remapping.
        self.model.load_weights(language_weights)

        # Visual weights load by exact name, honoring weight_loader hooks.
        if self.visual is not None and visual_weights:
            params_dict = dict(self.named_parameters())
            for name, weight in visual_weights:
                if name in params_dict:
                    param = params_dict[name]
                    loader = getattr(param, "weight_loader", None)
                    if loader is not None:
                        loader(param, weight)
                    else:
                        param.data.copy_(weight)

        logger.info("Qwen3_5ForConditionalGeneration weights loaded")
+""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pymllm.layers import RMSNorm, apply_mrope +from pymllm.layers.attention.radix_attention import RadixAttention +from pymllm.layers.mlp import MLP + +if TYPE_CHECKING: + from pymllm.engine.forward_batch import ForwardBatch + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Vision Encoder +# --------------------------------------------------------------------------- + + +class Qwen3VisionMLP(nn.Module): + """MLP block for the vision encoder.""" + + def __init__( + self, + in_features: int, + hidden_features: int, + hidden_act: str = "silu", + bias: bool = True, + ): + super().__init__() + self.linear_fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.linear_fc2 = nn.Linear(hidden_features, in_features, bias=bias) + if hidden_act == "gelu_pytorch_tanh": + self.act = nn.GELU(approximate="tanh") + elif hidden_act == "gelu": + self.act = nn.GELU() + else: + self.act = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_fc2(self.act(self.linear_fc1(x))) + + +class Qwen3VLVisionPatchEmbed(nn.Module): + """3D convolution patch embedding for video/image patchification.""" + + def __init__( + self, + patch_size: int = 16, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ): + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d( + in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=True, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = 
self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, + self.in_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view( + -1, self.embed_dim + ) + return hidden_states + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotate half the hidden dims of the input for RoPE.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class Qwen3VisionAttention(nn.Module): + """Multi-head self-attention for the vision encoder (no KV cache).""" + + def __init__(self, embed_dim: int, num_heads: int): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + + self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, + ) -> torch.Tensor: + """Forward pass with variable-length sequences via cu_seqlens. + + Args: + x: [total_tokens, embed_dim] + cu_seqlens: [num_seqs + 1] cumulative sequence lengths + rotary_pos_emb_cos: [total_tokens, rotary_dim] + rotary_pos_emb_sin: [total_tokens, rotary_dim] + """ + seq_len = x.shape[0] + qkv = self.qkv_proj(x) + q, k, v = qkv.reshape(seq_len, 3, self.num_heads, self.head_dim).unbind(dim=1) + + # Apply rotary position embedding. + # cos/sin are [total_tokens, head_dim // 2]. Following sglang's + # VisionAttention: double them to full head_dim and apply RoPE to + # all head dimensions (the rotation pairs (q[i], q[i + head_dim//2])). 
class Qwen3VisionBlock(nn.Module):
    """Single pre-norm vision transformer block (attention + MLP residuals)."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        intermediate_dim: int,
        hidden_act: str = "silu",
        norm_eps: float = 1e-6,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
        self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
        self.attn = Qwen3VisionAttention(embed_dim=dim, num_heads=num_heads)
        self.mlp = Qwen3VisionMLP(
            dim, intermediate_dim, hidden_act=hidden_act, bias=True
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
        )
        x = x + attn_out
        return x + self.mlp(self.norm2(x))
class Qwen3VLVisionPatchMerger(nn.Module):
    """Merges spatial patches to reduce sequence length.

    Groups ``spatial_merge_size ** 2`` consecutive patch tokens and projects
    them to the language model hidden dimension.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        use_postshuffle_norm: bool = False,
        norm_eps: float = 1e-6,
    ):
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.use_postshuffle_norm = use_postshuffle_norm
        # Postshuffle variant normalizes after grouping (over the merged
        # dim); otherwise normalization happens per-token before grouping.
        norm_dim = self.hidden_size if use_postshuffle_norm else context_dim
        self.norm = nn.LayerNorm(norm_dim, eps=norm_eps)
        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(self.hidden_size, dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_postshuffle_norm:
            merged = self.norm(x.view(-1, self.hidden_size))
        else:
            merged = self.norm(x).view(-1, self.hidden_size)
        return self.linear_fc2(self.act_fn(self.linear_fc1(merged)))
+ """ + + def __init__( + self, + depth: int = 27, + hidden_size: int = 1152, + hidden_act: str = "gelu_pytorch_tanh", + intermediate_size: int = 4304, + num_heads: int = 16, + in_channels: int = 3, + patch_size: int = 16, + spatial_merge_size: int = 2, + temporal_patch_size: int = 2, + out_hidden_size: int = 3584, + num_position_embeddings: int = 2304, + deepstack_visual_indexes: Optional[List[int]] = None, + norm_eps: float = 1e-6, + ): + super().__init__() + if deepstack_visual_indexes is None: + deepstack_visual_indexes = [8, 16, 24] + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_position_embeddings = num_position_embeddings + self.num_grid_per_side = int(num_position_embeddings**0.5) + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.deepstack_visual_indexes = deepstack_visual_indexes + # Total output dim = out_hidden_size * (1 main + N deepstack mergers) + self.out_hidden_size = out_hidden_size * (1 + len(deepstack_visual_indexes)) + + self.patch_embed = Qwen3VLVisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + embed_dim=hidden_size, + ) + + self.pos_embed = nn.Embedding(num_position_embeddings, hidden_size) + + head_dim = hidden_size // num_heads + self._init_rope_cache(head_dim) + + self.blocks = nn.ModuleList( + [ + Qwen3VisionBlock( + dim=hidden_size, + num_heads=num_heads, + intermediate_dim=intermediate_size, + hidden_act=hidden_act, + norm_eps=norm_eps, + ) + for _ in range(depth) + ] + ) + + self.merger = Qwen3VLVisionPatchMerger( + dim=out_hidden_size, + context_dim=hidden_size, + spatial_merge_size=spatial_merge_size, + norm_eps=norm_eps, + ) + + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3VLVisionPatchMerger( + dim=out_hidden_size, + context_dim=hidden_size, + spatial_merge_size=spatial_merge_size, + use_postshuffle_norm=True, + norm_eps=norm_eps, + ) + for _ in 
range(len(deepstack_visual_indexes)) + ] + ) + + def _init_rope_cache(self, head_dim: int, max_grid_size: int = 8192): + """Precompute cos/sin cache for 2D rotary embeddings.""" + rotary_dim = head_dim // 2 + inv_freq = 1.0 / ( + 10000.0 + ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim) + ) + t = torch.arange(max_grid_size, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("cos_cache", torch.cos(freqs), persistent=False) + self.register_buffer("sin_cache", torch.sin(freqs), persistent=False) + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + # -- Rotary position embedding helpers -- + + @staticmethod + def _rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor: + """Compute 2D rotary position IDs for a grid of *h* x *w* patches. + + The patches are re-ordered to group ``spatial_merge_size ** 2`` + neighbours together (matching the merger's token order). + + Returns tensor of shape ``[h*w, 2]`` with ``(height_pos, width_pos)``. 
+ """ + merge = spatial_merge_size + h_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + w_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + + h_ids = h_ids.reshape(h // merge, merge, w // merge, merge) + w_ids = w_ids.reshape(h // merge, merge, w // merge, merge) + + h_ids = h_ids.permute(0, 2, 1, 3).flatten() + w_ids = w_ids.permute(0, 2, 1, 3).flatten() + + return torch.stack([h_ids, w_ids], dim=-1) + + def rot_pos_emb( + self, grid_thw: List[List[int]] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute rotary pos-emb cos/sin for all images/videos in the batch.""" + pos_ids = [] + for t, h, w in grid_thw: + base = self._rot_pos_ids(h, w, self.spatial_merge_size) + pos_ids.append(base if t == 1 else base.repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True) + cos_combined = self.cos_cache[pos_ids].flatten(1) + sin_combined = self.sin_cache[pos_ids].flatten(1) + return cos_combined, sin_combined + + # -- Position embedding interpolation -- + + def _get_interpolation_indices(self, dim_size: int) -> np.ndarray: + indices = (np.arange(dim_size, dtype=np.float32) + 0.5) * ( + self.num_grid_per_side / dim_size + ) - 0.5 + return np.clip(indices, 0, self.num_grid_per_side - 1) + + def _calculate_indices_and_weights( + self, h_idxs: np.ndarray, w_idxs: np.ndarray + ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + """Compute bilinear interpolation indices and weights.""" + side = self.num_grid_per_side + h_f = np.floor(h_idxs).astype(np.int64) + h_c = np.clip(h_f + 1, 0, side - 1) + dh = h_idxs - h_f + w_f = np.floor(w_idxs).astype(np.int64) + w_c = np.clip(w_f + 1, 0, side - 1) + dw = w_idxs - w_f + + indices = [ + (h_f[:, None] * side + w_f).flatten(), + (h_f[:, None] * side + w_c).flatten(), + (h_c[:, None] * side + w_f).flatten(), + (h_c[:, None] * side + w_c).flatten(), + ] + weights = [ + ((1 - dh)[:, None] * (1 - dw)).flatten(), + ((1 - dh)[:, None] * dw).flatten(), + (dh[:, None] * (1 - dw)).flatten(), + (dh[:, None] * 
dw).flatten(), + ] + return indices, weights + + def _get_position_embedding( + self, + patch_pos_embeds: List[torch.Tensor], + grid_ts: List[int], + grid_hs: List[int], + grid_ws: List[int], + ) -> torch.Tensor: + """Tile and reorganize position embeddings to align with the merged token order.""" + result_parts = [] + merge = self.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge, merge, w // merge, merge, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + result_parts.append(pos_embed) + return torch.cat(result_parts, dim=0) + + def fast_pos_embed_interpolate(self, grid_thw: torch.Tensor) -> torch.Tensor: + """Interpolate position embeddings via bilinear interpolation.""" + grid_thw_cpu = grid_thw.cpu().numpy() + temporal_dims = grid_thw_cpu[:, 0].tolist() + height_dims = grid_thw_cpu[:, 1].tolist() + width_dims = grid_thw_cpu[:, 2].tolist() + + device = self.pos_embed.weight.device + dtype = self.pos_embed.weight.dtype + + patches_size = [h * w for h, w in zip(height_dims, width_dims)] + total_patches = sum(patches_size) + all_indices_np = np.zeros((4, total_patches), dtype=np.int64) + all_weights_np = np.zeros((4, total_patches), dtype=np.float32) + + current_idx = 0 + for _t, h, w in zip(temporal_dims, height_dims, width_dims): + h_idxs = self._get_interpolation_indices(h) + w_idxs = self._get_interpolation_indices(w) + indices, weights = self._calculate_indices_and_weights(h_idxs, w_idxs) + end_idx = current_idx + h * w + for i in range(4): + all_indices_np[i, current_idx:end_idx] = indices[i] + all_weights_np[i, current_idx:end_idx] = weights[i] + current_idx = end_idx + + idx_tensor = torch.from_numpy(all_indices_np).to(device) + weight_tensor = torch.from_numpy(all_weights_np).to(dtype=dtype, device=device) + + pos_embeds = self.pos_embed(idx_tensor.view(-1)) + pos_embeds = pos_embeds.view(4, total_patches, -1) + 
patch_pos_embeds = (pos_embeds * weight_tensor.unsqueeze(-1)).sum(dim=0) + patch_pos_embeds = patch_pos_embeds.split(patches_size) + return self._get_position_embedding( + list(patch_pos_embeds), temporal_dims, height_dims, width_dims + ) + + # -- Forward -- + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + """Run the vision encoder. + + Args: + x: Pixel values, shape ``[total_patches, patch_dim]``. + grid_thw: Grid dimensions ``[num_images, 3]`` with ``(T, H, W)``. + + Returns: + Vision features of shape + ``[num_merged_tokens, out_hidden_size * (1 + num_deepstack)]``. + """ + x = x.to(device=self.device, dtype=self.dtype) + x = self.patch_embed(x) + + if isinstance(grid_thw, list): + grid_thw_list = grid_thw + grid_thw = torch.tensor(grid_thw, dtype=torch.int32) + else: + grid_thw_list = grid_thw.tolist() + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + x += pos_embeds + + rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) + + cu_seqlens = _compute_cu_seqlens_from_grid(grid_thw) + cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) + + deepstack_features = [] + ds_idx = 0 + + for layer_num, blk in enumerate(self.blocks): + x = blk(x, cu_seqlens, rotary_pos_emb_cos, rotary_pos_emb_sin) + + if layer_num in self.deepstack_visual_indexes: + # x is [total_tokens, hidden]. The merger expects the last + # dim to be context_dim so it can group spatial_merge_size^2 + # tokens; reshape to [total_tokens, 1, hidden] so that the + # `.view(-1, hidden_size)` inside the merger collapses the + # spatial merge correctly. + ds_feat = self.deepstack_merger_list[ds_idx](x.unsqueeze(1)) + deepstack_features.append(ds_feat) + ds_idx += 1 + + x = self.merger(x.unsqueeze(1)) + + # Concatenate main + deepstack features along the feature dimension. 

    # Result: [num_merged_tokens, out_hidden_size * (1 + num_deepstack)]
        hidden_states = torch.cat([x] + deepstack_features, dim=-1)
        return hidden_states


def _compute_cu_seqlens_from_grid(grid_thw: torch.Tensor) -> torch.Tensor:
    """Compute cumulative sequence lengths from grid dimensions.

    Each image/video contributes one attention segment of ``t * h * w``
    patch tokens; the result is ``[0, s0, s0+s1, ...]`` as an int32 CPU
    tensor (callers move it to the target device).

    NOTE(review): this treats a whole video (all ``t`` frames) as a single
    attention segment. HF Qwen2/3-VL instead repeats ``h * w`` once per
    frame (``repeat_interleave`` over ``t``) — confirm which granularity
    the vision blocks expect for videos; for images (``t == 1``) both
    formulations agree.
    """
    grid_np = grid_thw.cpu().numpy()
    seq_lens = (grid_np[:, 0] * grid_np[:, 1] * grid_np[:, 2]).astype(np.int32)
    cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens)])
    return torch.tensor(cu_seqlens, dtype=torch.int32)


def _build_cos_sin_cache(
    head_dim: int,
    rope_theta: float,
    max_pos: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Build a [max_pos, head_dim] cos/sin cache for M-RoPE.

    Layout: first ``head_dim // 2`` columns are cos values, second half are sin.
    Each row corresponds to one position index.
    """
    # Standard RoPE inverse frequencies over even dimension indices.
    inv_freq = 1.0 / (
        rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    )
    t = torch.arange(max_pos, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)  # [max_pos, head_dim // 2]
    # Cast once at the end so cos/sin are computed in float32 for accuracy.
    return torch.cat([torch.cos(freqs), torch.sin(freqs)], dim=-1).to(dtype)


def get_rope_index(
    input_ids: torch.Tensor,
    image_grid_thw: Optional[torch.Tensor],
    image_token_id: int,
    vision_start_token_id: int,
    spatial_merge_size: int,
) -> Tuple[torch.Tensor, int]:
    """Compute M-RoPE 3-D position IDs for one sequence.

    For text tokens all three (temporal, height, width) indices are equal to
    the sequential counter. For image tokens the indices follow the spatial
    grid ``(t, h, w)``.

    Args:
        input_ids: Token IDs for one sequence, shape ``[T]``.
        image_grid_thw: Grid dimensions for every image in the sequence,
            shape ``[num_images, 3]``. ``None`` when there are no images.
        image_token_id: Token ID used as placeholder for image patches.
        vision_start_token_id: Token ID that precedes each image block.
        spatial_merge_size: Number of patches merged per spatial dimension
            (e.g.
2 → 2x2 merge, so llm_grid_h = H // 2). + + Returns: + ``(position_ids, mrope_position_delta)`` where ``position_ids`` has + shape ``[3, T]`` and ``mrope_position_delta`` is a Python ``int`` + equal to ``max_position_used + 1 - T``. + """ + total_tokens = input_ids.shape[0] + device = input_ids.device + position_ids = torch.zeros(3, total_tokens, dtype=torch.long, device=device) + + if image_grid_thw is None or image_grid_thw.shape[0] == 0: + pos = torch.arange(total_tokens, dtype=torch.long, device=device) + position_ids[0] = pos + position_ids[1] = pos + position_ids[2] = pos + return position_ids, 0 + + input_ids_cpu = input_ids.cpu().tolist() + grid_thw_list = image_grid_thw.cpu().tolist() + + llm_pos_ids_start = 0 + image_idx = 0 + i = 0 + + while i < total_tokens: + token = input_ids_cpu[i] + + if token == vision_start_token_id and image_idx < len(grid_thw_list): + # The vision_start token itself gets a regular sequential position. + position_ids[:, i] = llm_pos_ids_start + llm_pos_ids_start += 1 + i += 1 + + # Compute LLM-side grid dimensions (after spatial merging). + t_g = int(grid_thw_list[image_idx][0]) + h_g = int(grid_thw_list[image_idx][1]) + w_g = int(grid_thw_list[image_idx][2]) + llm_grid_t = t_g + llm_grid_h = h_g // spatial_merge_size + llm_grid_w = w_g // spatial_merge_size + num_image_tokens = llm_grid_t * llm_grid_h * llm_grid_w + + # Build per-patch 3-D indices. 
+ t_idx = ( + torch.arange(llm_grid_t, device=device) + .view(-1, 1, 1) + .expand(-1, llm_grid_h, llm_grid_w) + .flatten() + ) + h_idx = ( + torch.arange(llm_grid_h, device=device) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_idx = ( + torch.arange(llm_grid_w, device=device) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + + img_start = i + img_end = i + num_image_tokens + position_ids[0, img_start:img_end] = t_idx + llm_pos_ids_start + position_ids[1, img_start:img_end] = h_idx + llm_pos_ids_start + position_ids[2, img_start:img_end] = w_idx + llm_pos_ids_start + + llm_pos_ids_start += max(llm_grid_t, llm_grid_h, llm_grid_w) + i += num_image_tokens + image_idx += 1 + else: + # Text token (including vision_end and all non-image tokens). + position_ids[:, i] = llm_pos_ids_start + llm_pos_ids_start += 1 + i += 1 + + mrope_position_delta = llm_pos_ids_start - total_tokens + return position_ids, mrope_position_delta + + +# --------------------------------------------------------------------------- +# Text Decoder (Language Model) +# --------------------------------------------------------------------------- + + +class Qwen3VLAttention(nn.Module): + """Attention layer for the Qwen3-VL text decoder. + + Uses QK-norm (per-head RMSNorm on Q and K before RoPE) and + :class:`RadixAttention` for KV-cached inference. Applies + interleaved M-RoPE with a precomputed cos/sin cache. 
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + layer_id: int, + rope_theta: float = 5_000_000.0, + rms_norm_eps: float = 1e-6, + mrope_section: Tuple[int, int, int] = (24, 20, 20), + mrope_interleaved: bool = True, + max_position_embeddings: int = 32768, + ): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.q_size = num_heads * head_dim + self.kv_size = num_kv_heads * head_dim + self.scaling = head_dim**-0.5 + self.mrope_section = list(mrope_section) + self.mrope_interleaved = mrope_interleaved + + # Fused QKV projection + self.qkv_proj = nn.Linear( + hidden_size, self.q_size + 2 * self.kv_size, bias=False + ) + + # Output projection + self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False) + + # QK normalization + self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps) + + # Precomputed M-RoPE cos/sin cache: [max_pos, head_dim] + cos_sin = _build_cos_sin_cache( + head_dim, rope_theta, max_position_embeddings, torch.float32 + ) + self.register_buffer("cos_sin_cache", cos_sin, persistent=False) + + # Radix attention (single-GPU: heads == tp_heads) + self.attn = RadixAttention( + num_heads=num_heads, + head_dim=head_dim, + scaling=self.scaling, + num_kv_heads=num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: "ForwardBatch", + ) -> torch.Tensor: + qkv = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Per-head QK normalization + q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)) + k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)) + + # Apply M-RoPE. positions is [3, T] for prefill (3-D) or may arrive + # as [T] for purely text-only batches; expand to [3, T] in that case. 
+ if positions.ndim == 1: + positions = positions.unsqueeze(0).expand(3, -1) + q, k = apply_mrope( + q, + k, + positions, + self.cos_sin_cache.to(q.dtype), + self.mrope_section, + self.mrope_interleaved, + ) + + q = q.reshape(-1, self.q_size) + k = k.reshape(-1, self.kv_size) + + # Attention with KV cache + attn_output = self.attn(q, k, v, forward_batch) + return self.o_proj(attn_output) + + +class Qwen3VLDecoderLayer(nn.Module): + """Single decoder layer for the Qwen3-VL text model.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + intermediate_size: int, + layer_id: int, + rope_theta: float = 5_000_000.0, + rms_norm_eps: float = 1e-6, + mrope_section: Tuple[int, int, int] = (24, 20, 20), + mrope_interleaved: bool = True, + max_position_embeddings: int = 32768, + ): + super().__init__() + self.self_attn = Qwen3VLAttention( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + layer_id=layer_id, + rope_theta=rope_theta, + rms_norm_eps=rms_norm_eps, + mrope_section=mrope_section, + mrope_interleaved=mrope_interleaved, + max_position_embeddings=max_position_embeddings, + ) + self.mlp = MLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation="silu", + use_fused_gate_up_proj=True, + use_bias_gate_up=False, + use_bias_down=False, + ) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: "ForwardBatch", + deepstack_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Self-attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(positions, hidden_states, forward_batch) + hidden_states = residual + hidden_states + + # Add deepstack embeddings after residual (matches HF ordering) + if 
deepstack_embeds is not None: + hidden_states = hidden_states + deepstack_embeds + + # MLP + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen3VLTextModel(nn.Module): + """Qwen3-VL text backbone (embedding + decoder layers + final norm).""" + + def __init__( + self, + vocab_size: int = 151936, + hidden_size: int = 4096, + intermediate_size: int = 22016, + num_hidden_layers: int = 32, + num_attention_heads: int = 32, + num_key_value_heads: int = 32, + head_dim: int = 128, + rope_theta: float = 5_000_000.0, + rms_norm_eps: float = 1e-6, + mrope_section: Tuple[int, int, int] = (24, 20, 20), + mrope_interleaved: bool = True, + max_position_embeddings: int = 32768, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + + self.embed_tokens = nn.Embedding(vocab_size, hidden_size) + + self.layers = nn.ModuleList( + [ + Qwen3VLDecoderLayer( + hidden_size=hidden_size, + num_heads=num_attention_heads, + num_kv_heads=num_key_value_heads, + head_dim=head_dim, + intermediate_size=intermediate_size, + layer_id=layer_id, + rope_theta=rope_theta, + rms_norm_eps=rms_norm_eps, + mrope_section=mrope_section, + mrope_interleaved=mrope_interleaved, + max_position_embeddings=max_position_embeddings, + ) + for layer_id in range(num_hidden_layers) + ] + ) + + self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: "ForwardBatch", + input_embeds: Optional[torch.Tensor] = None, + input_deepstack_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + for layer_idx, layer in enumerate(self.layers): + ds_embeds = _get_deepstack_embeds( + layer_idx, input_deepstack_embeds, 
                self.hidden_size
            )
            hidden_states = layer(
                positions,
                hidden_states,
                forward_batch,
                deepstack_embeds=ds_embeds,
            )

        return self.norm(hidden_states)


def _get_deepstack_embeds(
    layer_idx: int,
    input_deepstack_embeds: Optional[torch.Tensor],
    hidden_size: int,
) -> Optional[torch.Tensor]:
    """Extract deepstack embeddings for a specific decoder layer.

    The packed tensor holds one ``hidden_size``-wide slice per deepstack
    level along its last dimension; decoder layer ``i`` receives slice
    ``i``, and layers beyond the number of packed slices get ``None``.

    Args:
        layer_idx: Index of the decoder layer requesting its slice.
        input_deepstack_embeds: Packed per-token tensor of shape
            ``[num_tokens, hidden_size * num_deepstack]``, or ``None``
            when no deepstack features were produced.
        hidden_size: Width of a single deepstack slice (text hidden size).

    Returns:
        The ``[num_tokens, hidden_size]`` slice for this layer, or ``None``.
    """
    if input_deepstack_embeds is None:
        return None
    # Number of packed slices is derived from the tensor width itself.
    num_deepstack = input_deepstack_embeds.shape[-1] // hidden_size
    if layer_idx >= num_deepstack:
        return None
    start = hidden_size * layer_idx
    return input_deepstack_embeds[:, start : start + hidden_size]


# ---------------------------------------------------------------------------
# Full Model: Qwen3VLForConditionalGeneration
# ---------------------------------------------------------------------------


class Qwen3VLForConditionalGeneration(nn.Module):
    """Qwen3-VL multimodal model for conditional generation.

    Combines a vision encoder and text decoder. During prefill, image/video
    tokens are replaced with visual features from the vision encoder.
    During decode, the model runs only the text decoder.
+ + Forward interface:: + + logits = model.forward(input_ids, positions, forward_batch) + """ + + def __init__(self, config) -> None: + super().__init__() + self.config = config + + text_config = getattr(config, "text_config", config) + vision_config = getattr(config, "vision_config", None) + + # Vision encoder + if vision_config is not None: + self.visual = Qwen3VLVisionModel( + depth=getattr(vision_config, "depth", 27), + hidden_size=getattr(vision_config, "hidden_size", 1152), + hidden_act=getattr(vision_config, "hidden_act", "gelu_pytorch_tanh"), + intermediate_size=getattr(vision_config, "intermediate_size", 4304), + num_heads=getattr(vision_config, "num_heads", 16), + in_channels=getattr(vision_config, "in_channels", 3), + patch_size=getattr(vision_config, "patch_size", 16), + spatial_merge_size=getattr(vision_config, "spatial_merge_size", 2), + temporal_patch_size=getattr(vision_config, "temporal_patch_size", 2), + out_hidden_size=getattr(vision_config, "out_hidden_size", 3584), + num_position_embeddings=getattr( + vision_config, "num_position_embeddings", 2304 + ), + deepstack_visual_indexes=getattr( + vision_config, "deepstack_visual_indexes", [8, 16, 24] + ), + norm_eps=getattr(text_config, "rms_norm_eps", 1e-6), + ) + else: + self.visual = None + + # Text decoder + hidden_size = getattr(text_config, "hidden_size", 4096) + vocab_size = getattr(text_config, "vocab_size", 151936) + + # M-RoPE configuration -- mrope_section lives inside rope_scaling, + # NOT as a top-level attribute of text_config. 
+ rope_scaling = getattr(text_config, "rope_scaling", None) or {} + if isinstance(rope_scaling, dict): + mrope_section = rope_scaling.get("mrope_section", [24, 20, 20]) + mrope_interleaved = rope_scaling.get("mrope_interleaved", True) + else: + mrope_section = getattr(rope_scaling, "mrope_section", [24, 20, 20]) + mrope_interleaved = getattr(rope_scaling, "mrope_interleaved", True) + max_position_embeddings = getattr(text_config, "max_position_embeddings", 32768) + + self.model = Qwen3VLTextModel( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=getattr(text_config, "intermediate_size", 22016), + num_hidden_layers=getattr(text_config, "num_hidden_layers", 32), + num_attention_heads=getattr(text_config, "num_attention_heads", 32), + num_key_value_heads=getattr(text_config, "num_key_value_heads", 32), + head_dim=getattr(text_config, "head_dim", 128), + rope_theta=getattr(text_config, "rope_theta", 5_000_000.0), + rms_norm_eps=getattr(text_config, "rms_norm_eps", 1e-6), + mrope_section=tuple(mrope_section), + mrope_interleaved=bool(mrope_interleaved), + max_position_embeddings=max_position_embeddings, + ) + + # LM head — following sglang's pattern: always use lm_head.weight + # for matmul in forward(), so it works whether lm_head is nn.Embedding + # (tied) or nn.Linear (untied). 
+ tie_word_embeddings = getattr(config, "tie_word_embeddings", False) + if tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False) + + # Token IDs for multimodal + self.image_token_id = getattr(config, "image_token_id", 151655) + self.video_token_id = getattr(config, "video_token_id", 151656) + self.vision_start_token_id = getattr(config, "vision_start_token_id", 151652) + + # Spatial merge size (needed for get_rope_index) + self.spatial_merge_size = ( + getattr(vision_config, "spatial_merge_size", 2) + if vision_config is not None + else 2 + ) + + # Deepstack config + if vision_config is not None: + ds_indexes = getattr(vision_config, "deepstack_visual_indexes", [8, 16, 24]) + self.num_deepstack_embeddings = len(ds_indexes) + else: + self.num_deepstack_embeddings = 0 + + self._hidden_size = hidden_size + + def get_input_embeddings(self) -> nn.Module: + return self.model.embed_tokens + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: "ForwardBatch", + ) -> torch.Tensor: + """Run forward pass for Qwen3-VL. + + Args: + input_ids: Flattened input token IDs, shape ``[num_tokens]``. + positions: Position IDs, shape ``[num_tokens]`` (1-D, from model + runner). Overridden internally with 3-D M-RoPE positions. + forward_batch: :class:`ForwardBatch` with attention metadata. + + Returns: + Logits tensor of shape ``[num_tokens, vocab_size]``. + """ + pixel_values = getattr(forward_batch, "pixel_values", None) + image_grid_thw = getattr(forward_batch, "image_grid_thw", None) + + # ------------------------------------------------------------------ + # Build 3-D M-RoPE positions + # ------------------------------------------------------------------ + if forward_batch.forward_mode.is_extend(): + # Prefill: compute per-sequence 3-D position IDs from input_ids + # and image grids, then store per-request deltas for future decode. 
+ mrope_positions_list: List[torch.Tensor] = [] + deltas: List[int] = [] + image_idx_offset = 0 + + for i in range(forward_batch.batch_size): + start = int(forward_batch.extend_start_loc[i].item()) + length = int(forward_batch.extend_seq_lens[i].item()) + seq_ids = input_ids[start : start + length] + + # Determine how many images belong to this sequence. + num_img = int((seq_ids == self.vision_start_token_id).sum().item()) + if image_grid_thw is not None and num_img > 0: + thw_seq = image_grid_thw[ + image_idx_offset : image_idx_offset + num_img + ] + image_idx_offset += num_img + else: + thw_seq = None + + pos3d, delta = get_rope_index( + seq_ids, + thw_seq, + self.image_token_id, + self.vision_start_token_id, + self.spatial_merge_size, + ) + mrope_positions_list.append(pos3d) + deltas.append(delta) + + # Concatenate across sequences: [3, total_extend_tokens] + positions = torch.cat(mrope_positions_list, dim=1) + forward_batch.mrope_position_deltas = torch.tensor( + deltas, dtype=torch.int64, device=input_ids.device + ) + else: + # Decode: each sequence emits exactly one token. Apply the stored + # per-request delta so the position matches the image extent. 
+ stored_deltas = getattr(forward_batch, "mrope_position_deltas", None) + if stored_deltas is not None: + pos_1d = forward_batch.positions + stored_deltas + else: + pos_1d = forward_batch.positions + positions = pos_1d.unsqueeze(0).expand(3, -1) # [3, batch_size] + + input_embeds = None + input_deepstack_embeds = None + + if ( + pixel_values is not None + and image_grid_thw is not None + and self.visual is not None + and not forward_batch.forward_mode.is_decode() + ): + # Run vision encoder + vision_features = self.visual(pixel_values, grid_thw=image_grid_thw) + + # Separate main embeddings and deepstack embeddings + if self.num_deepstack_embeddings > 0: + vision_embeds = vision_features[:, : self._hidden_size] + deepstack_embeds = vision_features[:, self._hidden_size :] + else: + vision_embeds = vision_features + deepstack_embeds = None + + # Get text embeddings and replace image tokens with vision features + input_embeds = self.model.embed_tokens(input_ids) + image_mask = input_ids == self.image_token_id + if image_mask.any(): + input_embeds[image_mask] = vision_embeds.to(input_embeds.dtype) + + # Build per-token deepstack embeddings + if deepstack_embeds is not None and image_mask.any(): + input_deepstack_embeds = torch.zeros( + input_embeds.shape[0], + deepstack_embeds.shape[-1], + dtype=input_embeds.dtype, + device=input_embeds.device, + ) + input_deepstack_embeds[image_mask] = deepstack_embeds.to( + input_embeds.dtype + ) + + # Text decoder + hidden_states = self.model( + input_ids, + positions, + forward_batch, + input_embeds=input_embeds, + input_deepstack_embeds=input_deepstack_embeds, + ) + + # Prune hidden_states before lm_head to avoid a wasteful + # [total_tokens, vocab] matmul during prefill. Following sglang's + # LogitsProcessor._get_pruned_states(): in extend mode only keep + # the last token of each sequence; in decode mode all rows are + # already one-per-sequence. 
+ if forward_batch.forward_mode.is_extend(): + if ( + forward_batch.extend_start_loc is not None + and forward_batch.extend_seq_lens is not None + ): + last_index = ( + forward_batch.extend_start_loc + forward_batch.extend_seq_lens - 1 + ).long() + hidden_states = hidden_states[last_index] + else: + hidden_states = hidden_states[-1:] + + # LM head: always use weight matrix directly for the linear + # projection. Works for both nn.Embedding (tied) and nn.Linear + # (untied). Matches sglang LogitsProcessor._compute_lm_head(). + logits = torch.matmul( + hidden_states.to(self.lm_head.weight.dtype), + self.lm_head.weight.T, + ) + + # Return LogitsProcessorOutput so that ModelRunner._process_logits + # skips redundant last-token gathering. + from pymllm.executor.model_runner import LogitsProcessorOutput + + return LogitsProcessorOutput(next_token_logits=logits) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: + """Load weights from a HuggingFace checkpoint. + + Handles weight name remapping between HuggingFace Qwen3-VL + checkpoints and this model's parameter names. + """ + stacked_params_mapping = [ + # (param_name, weight_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".up_proj", 1), + (".gate_up_proj", ".gate_proj", 0), + ] + + params_dict = dict(self.named_parameters()) + + tie_word_embeddings = getattr(self.config, "tie_word_embeddings", False) + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + # When weights are tied, lm_head.weight is the same tensor as + # embed_tokens.weight — skip the duplicate from the checkpoint. 
            if tie_word_embeddings and "lm_head.weight" in name:
                continue

            name = _remap_weight_name(name)

            # Handle language model stacked parameters (QKV, gate_up)
            handled = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name or "visual" in name:
                    continue
                # NOTE(review): `name` is mutated before the membership
                # check below; if the fused key is absent, later mapping
                # iterations see the already-renamed string. Harmless as
                # long as each checkpoint name matches at most one mapping
                # entry — verify if new mappings are added.
                name = name.replace(weight_name, param_name)
                if name not in params_dict:
                    continue
                _load_stacked_weight(params_dict[name], loaded_weight, shard_id)
                handled = True
                break

            if handled:
                continue

            # Handle vision encoder QKV stacking
            if "visual" in name:
                for qkv_key in (".attn.q.", ".attn.k.", ".attn.v."):
                    if qkv_key not in name:
                        continue
                    qkv_name = name.replace(qkv_key, ".attn.qkv_proj.")
                    if qkv_name in params_dict:
                        # qkv_key[-2] is the letter before the trailing dot
                        # ("q", "k" or "v") and selects the fused shard.
                        shard = {"q": 0, "k": 1, "v": 2}[qkv_key[-2]]
                        _load_vision_qkv_weight(
                            params_dict[qkv_name], loaded_weight, shard
                        )
                        handled = True
                        break

            if handled:
                continue

            # Direct parameter loading
            if name in params_dict:
                param = params_dict[name]
                if param.data.shape == loaded_weight.shape:
                    param.data.copy_(loaded_weight)
                else:
                    # Skip rather than crash on mismatched checkpoints.
                    logger.warning(
                        "Shape mismatch: param %s (%s) vs loaded (%s), skipping.",
                        name,
                        param.data.shape,
                        loaded_weight.shape,
                    )


# ---------------------------------------------------------------------------
# Weight loading helpers
# ---------------------------------------------------------------------------


def _remap_weight_name(name: str) -> str:
    """Remap HuggingFace weight names to pymllm parameter names.

    Args:
        name: Parameter name as it appears in the HF checkpoint.

    Returns:
        The corresponding name in this module's ``named_parameters()``.
    """
    # transformers >= v4.52: model.language_model.* -> model.*
    if name.startswith("model.language_model."):
        name = name.replace("model.language_model.", "model.", 1)
    # model.visual.* -> visual.*
    elif name.startswith("model.visual."):
        name = name.replace("model.visual.", "visual.", 1)

    # Vision attention QKV renaming (fused weights in checkpoint)
    if "visual" in name:
        name = name.replace("attn.qkv.", "attn.qkv_proj.")

    return name

def _load_stacked_weight(
    param: nn.Parameter,
    loaded_weight: torch.Tensor,
    shard_id,
) -> None:
    """Load one shard (q/k/v or gate/up) into a fused parameter.

    For QKV with GQA (grouped-query attention), Q has a different size
    from K and V. The fused layout is ``[Q, K, V]`` where
    ``Q_size = total - 2 * KV_size``. We must use cumulative offsets
    rather than ``idx * shard_size`` to handle the asymmetry correctly.

    Args:
        param: Destination fused parameter (``qkv_proj`` or ``gate_up_proj``).
        loaded_weight: The single shard loaded from the checkpoint.
        shard_id: ``"q"``/``"k"``/``"v"`` for QKV, or ``0``/``1`` for
            gate/up (see ``stacked_params_mapping``).
    """
    # Only dim 0 is indexed, so this works for 2-D weights and 1-D biases alike.
    if isinstance(shard_id, str):
        # QKV fused layout: [Q, K, V]
        # Q may have a different size from K/V (GQA).
        total_size = param.data.shape[0]
        shard_size = loaded_weight.shape[0]
        if shard_id == "q":
            param.data[0:shard_size].copy_(loaded_weight)
        elif shard_id == "k":
            kv_size = shard_size
            q_size = total_size - 2 * kv_size
            param.data[q_size : q_size + kv_size].copy_(loaded_weight)
        elif shard_id == "v":
            kv_size = shard_size
            q_size = total_size - 2 * kv_size
            param.data[q_size + kv_size : q_size + 2 * kv_size].copy_(
                loaded_weight
            )
    else:
        # gate_up: 0 -> gate, 1 -> up (same size, idx*size is correct)
        shard_size = loaded_weight.shape[0]
        param.data[shard_id * shard_size : (shard_id + 1) * shard_size].copy_(
            loaded_weight
        )


def _load_vision_qkv_weight(
    param: nn.Parameter,
    loaded_weight: torch.Tensor,
    shard_idx: int,
) -> None:
    """Load a Q, K, or V weight shard into a fused QKV parameter.

    Unlike the text model, vision attention has equal-sized Q/K/V
    (no GQA), so the fused dim splits evenly into three shards.

    Args:
        param: Destination fused ``qkv_proj`` parameter.
        loaded_weight: One Q/K/V shard from the checkpoint.
        shard_idx: 0 for Q, 1 for K, 2 for V.
    """
    shard_size = param.data.shape[0] // 3
    start = shard_idx * shard_size
    param.data[start : start + shard_size].copy_(loaded_weight)
diff --git a/pymllm/orchestrator/__init__.py b/pymllm/orchestrator/__init__.py
new file mode 100644
index 000000000..f1716d794
--- /dev/null
+++ b/pymllm/orchestrator/__init__.py
@@ -0,0 +1,48 @@
"""Orchestrator module for distributed computation."""

from pymllm.orchestrator.group_coordinator import (
    GroupCoordinator,
    divide,
    split_tensor_along_dim,
)
from pymllm.orchestrator.parallel_state import (
    data_parallel_all_reduce,
    get_data_parallel_rank,
get_data_parallel_world_size, + get_dp_group, + get_pipeline_model_parallel_rank, + get_pipeline_model_parallel_world_size, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, + initialize_model_parallel, + model_parallel_is_initialized, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) + +__all__ = [ + # GroupCoordinator + "GroupCoordinator", + "divide", + "split_tensor_along_dim", + # TP + "get_tp_group", + "get_tensor_model_parallel_rank", + "get_tensor_model_parallel_world_size", + "tensor_model_parallel_all_reduce", + "tensor_model_parallel_all_gather", + # DP + "get_dp_group", + "get_data_parallel_rank", + "get_data_parallel_world_size", + "data_parallel_all_reduce", + # PP + "get_pp_group", + "get_pipeline_model_parallel_rank", + "get_pipeline_model_parallel_world_size", + # State + "initialize_model_parallel", + "model_parallel_is_initialized", +] diff --git a/pymllm/orchestrator/cuda_ipc_transport.py b/pymllm/orchestrator/cuda_ipc_transport.py new file mode 100644 index 000000000..938132c8b --- /dev/null +++ b/pymllm/orchestrator/cuda_ipc_transport.py @@ -0,0 +1,648 @@ +""" +CUDA IPC Transport for zero-copy GPU tensor sharing between processes. + +## Background + +When sharing CUDA tensors between processes, there are two fundamentally different paths: + +1. **CPU shared memory path** (``enable_shared_queue=True, enable_cuda_ipc=False``): + GPU tensors are moved to CPU / POSIX shared memory via ``tensor.share_memory_()``. + This is safe but incurs a GPU→CPU copy which is expensive for large vision features. + +2. **CUDA IPC path** (``enable_cuda_ipc=True``): + GPU tensors stay on GPU. PyTorch's ``storage._share_cuda_()`` yields a serialisable + IPC handle; the receiver calls ``UntypedStorage._new_shared_cuda(*handle)`` to map + the same physical GPU memory without any copy. + +These two paths are **mutually exclusive for GPU tensors**. 
``enable_cuda_ipc`` takes +priority; when active the CPU-copy step in ``TensorQueue._make_tensors_shareable`` is +skipped. + +## CUDA IPC memory-leak problem and its fix + +PyTorch never releases the GPU allocation backing an IPC-exported tensor until the +*sending* process exits. If we export raw model tensors we permanently leak GPU memory. + +**Solution** (pool-based recycling via ``MmItemMemoryPool``): + +* Allocate a single, fixed-size GPU workspace (``MmItemMemoryPool``). +* For each outgoing GPU tensor, copy it into a chunk of the workspace and export the + *chunk* via IPC (the workspace is never freed; its chunks are recycled). +* After the receiving process has finished with the data it writes a sync flag + (``ShmSyncBuffer``) to signal that the chunk may be reused. +* A background recycler thread in the sender walks ``occupied_chunks`` and returns + chunks whose sync flag has been incremented back to ``available_chunks``. + +## Transport modes + +``TensorTransportMode``: +* ``"default"`` – CPU/shared-memory path; no CUDA IPC. +* ``"cuda_ipc"`` – Simple CUDA IPC: wraps GPU tensors in ``TransportProxyTensor`` + (a ``torch.Tensor`` subclass whose ``__getstate__``/``__setstate__`` use + ``_share_cuda_``). Suitable for single-process-group scenarios; incurs the + PyTorch memory-leak noted above. +* ``"cuda_ipc_pool"`` – Pool-based CUDA IPC: copies GPU tensors into a pre-allocated + ``MmItemMemoryPool`` and wraps the slice in ``CudaIpcTensorTransportProxy``. + The pool is recycled, so there is no memory leak. 
+""" + +from __future__ import annotations + +import fcntl +import logging +import threading +import time +from multiprocessing import shared_memory +from typing import Any, Dict, List, Literal, Optional, Tuple + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Type alias for transport mode +# --------------------------------------------------------------------------- + +TensorTransportMode = Literal["default", "cuda_ipc", "cuda_ipc_pool"] + + +# --------------------------------------------------------------------------- +# ShmSyncBuffer – a tiny POSIX shared memory float used as a sync counter +# --------------------------------------------------------------------------- + + +class ShmSyncBuffer: + """A single float32 in POSIX shared memory used as a sync counter. + + The sender resets it to 0 before exporting a chunk. The receiver + increments it (atomically under a file lock) once it has finished copying + data out of the chunk. When the value reaches the number of consumers + (``tp_size``) the sender recycles the chunk. 
# Lock file used to serialise writes to sync flags across processes
_SHM_LOCK_FILE = "/tmp/pymllm_shm_wr_lock.lock"


def _increment_sync_flag(meta_data: Dict[str, Any]) -> None:
    """Increment the sync flag by 1 under a process-level file lock.

    Opens the shared-memory counter described by *meta_data*, bumps it
    (serialised across processes via ``flock`` on ``_SHM_LOCK_FILE``),
    and closes the shared-memory mapping again.

    Args:
        meta_data: The ``ShmSyncBuffer.meta_data`` dict (``handle``,
            ``shape``, ``dtype``) identifying the counter to bump.
    """
    shm, arr = ShmSyncBuffer.open(meta_data)
    try:
        # "a+" creates the lock file if missing WITHOUT truncating it.
        # The previous two-step open (ensure-exists, then "w+") truncated
        # a file that other processes may concurrently be flock()ing.
        with open(_SHM_LOCK_FILE, "a+") as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            try:
                arr += 1.0
            finally:
                fcntl.flock(f, fcntl.LOCK_UN)
    finally:
        shm.close()
self.area[0] + + @property + def start(self) -> int: + return self.area[0] + + @property + def end(self) -> int: + return self.area[1] + + def try_to_recycle(self, num_consumers: int = 1) -> bool: + """Return True if all consumers have finished and the chunk can be reused.""" + val = float(self.sync_flag._arr.item()) + logger.debug( + "[try_to_recycle] area=%s flag=%.0f consumers=%d", + self.area, + val, + num_consumers, + ) + if val >= float(num_consumers): + self.sync_flag._arr *= 0.0 # reset for next use + return True + return False + + +# --------------------------------------------------------------------------- +# MmItemMemoryPool – pre-allocated GPU workspace to avoid IPC memory leaks +# --------------------------------------------------------------------------- + + +class MmItemMemoryPool: + """Pre-allocated GPU memory pool for CUDA IPC tensor transport. + + Chunks are allocated from a contiguous ``torch.int8`` tensor on GPU. + A background thread periodically recycles chunks whose sync flags show + that all consumers have finished reading. + + Args: + memory_size: Pool size in **bytes**. + recycle_interval: How often (seconds) the recycler thread runs. + num_consumers: Number of consumer processes (tp_size). Each consumer + must increment the sync flag once before a chunk is recycled. + device: CUDA device index. 
+ """ + + def __init__( + self, + memory_size: int, + recycle_interval: float = 0.1, + num_consumers: int = 1, + device: int = 0, + ) -> None: + self.num_consumers = num_consumers + self._recycle_interval = recycle_interval + self._lock = threading.Lock() + self._stop = False + + with torch.cuda.device(device): + self.memory_pool: torch.Tensor = torch.empty( + memory_size, dtype=torch.int8, device=f"cuda:{device}" + ).contiguous() + + init_chunk = MmItemMemoryChunk((0, memory_size), self._new_sync_buffer()) + self.available_chunks: List[MmItemMemoryChunk] = [init_chunk] + self.occupied_chunks: List[MmItemMemoryChunk] = [] + # Pool of reusable ShmSyncBuffer objects (returned from recycled chunks) + self._sync_pool: List[ShmSyncBuffer] = [] + + self._recycler = threading.Thread( + target=self._recycle_loop, + name="MmItemMemoryPoolRecycler", + daemon=True, + ) + self._recycler.start() + + logger.info( + "MmItemMemoryPool: %d MB on cuda:%d, recycle_interval=%.2fs", + memory_size // (1024 * 1024), + device, + recycle_interval, + ) + + # ------------------------------------------------------------------ + # Sync buffer management + # ------------------------------------------------------------------ + + def _new_sync_buffer(self) -> ShmSyncBuffer: + if self._sync_pool: + return self._sync_pool.pop() + return ShmSyncBuffer() + + def _return_sync_buffer(self, buf: ShmSyncBuffer) -> None: + buf._arr *= 0.0 # reset counter + self._sync_pool.append(buf) + + # ------------------------------------------------------------------ + # Allocation + # ------------------------------------------------------------------ + + def _get_available_chunk(self, src: torch.Tensor) -> Optional[MmItemMemoryChunk]: + """Best-fit allocation: find the smallest available chunk >= src size.""" + needed = src.numel() * src.element_size() + best: Optional[MmItemMemoryChunk] = None + for chunk in self.available_chunks: + if chunk.mem_size >= needed: + if best is None or chunk.mem_size < best.mem_size: + 
best = chunk + if best is None: + return None + + # Split the selected chunk + occupied_area = (best.start, best.start + needed) + occupied = MmItemMemoryChunk(occupied_area, best.sync_flag) + self.occupied_chunks.append(occupied) + self.available_chunks.remove(best) + + remainder = (occupied.end, best.end) + if remainder[0] < remainder[1]: + split = MmItemMemoryChunk(remainder, self._new_sync_buffer()) + self.available_chunks.append(split) + + return occupied + + def get_slice_with_flag( + self, src: torch.Tensor + ) -> Tuple[Optional[Dict[str, Any]], Optional[torch.Tensor]]: + """Allocate a pool slice for *src* and return ``(sync_flag_meta, slice_tensor)``. + + Thread-safe. Returns ``(None, None)`` if the pool is full. + """ + with self._lock: + chunk = self._get_available_chunk(src) + if chunk is None: + logger.warning( + "MmItemMemoryPool full (%d occupied, %d available); " + "falling back to CPU transport", + len(self.occupied_chunks), + len(self.available_chunks), + ) + return None, None + pool_slice = self.memory_pool[chunk.start : chunk.end] + return chunk.sync_flag.meta_data, pool_slice + + # ------------------------------------------------------------------ + # Recycling + # ------------------------------------------------------------------ + + def _recycle_loop(self) -> None: + while not self._stop: + try: + with self._lock: + self._recycle_chunks() + self._merge_chunks() + except Exception as exc: + logger.warning( + "MmItemMemoryPool recycler error: %s", exc, exc_info=True + ) + time.sleep(self._recycle_interval) + + def _recycle_chunks(self) -> None: + new_occupied: List[MmItemMemoryChunk] = [] + for chunk in self.occupied_chunks: + if chunk.try_to_recycle(self.num_consumers): + self._return_sync_buffer(chunk.sync_flag) + chunk.sync_flag = self._new_sync_buffer() + self.available_chunks.append(chunk) + else: + new_occupied.append(chunk) + self.occupied_chunks = new_occupied + + def _merge_chunks(self) -> None: + """Coalesce adjacent free chunks to 
reduce fragmentation.""" + merged: List[MmItemMemoryChunk] = [] + for chunk in sorted(self.available_chunks, key=lambda c: c.start): + if merged and merged[-1].end == chunk.start: + prev = merged.pop() + self._return_sync_buffer(chunk.sync_flag) + merged.append( + MmItemMemoryChunk((prev.start, chunk.end), prev.sync_flag) + ) + else: + merged.append(chunk) + self.available_chunks = merged + + def shutdown(self) -> None: + self._stop = True + if self._recycler.is_alive(): + self._recycler.join(timeout=2.0) + + +# --------------------------------------------------------------------------- +# CudaIpcTensorTransportProxy – pool-based CUDA IPC proxy object +# --------------------------------------------------------------------------- + + +class CudaIpcTensorTransportProxy: + """Proxy that carries a CUDA IPC handle for a pool-slice tensor. + + The *sender* process: + 1. Copies the source tensor into a ``MmItemMemoryPool`` slice (int8 view). + 2. Wraps the slice in this proxy, which captures the CUDA IPC handle via + ``storage._share_cuda_()``. + 3. Sends the proxy through ``multiprocessing.Queue`` (pickle). + + The *receiver* process: + 1. Calls :meth:`reconstruct_on_device` to map the IPC memory and copy it + into a fresh local tensor. + 2. The copy increments the sync flag, allowing the sender's recycler to + reclaim the pool slice. + + Fallback: if ``_share_cuda_()`` fails (e.g. TP ranks), ``tensor_data`` holds + the raw tensor (which will be pickled the normal way, incurring serialization cost). 
+ """ + + def __init__( + self, + data: torch.Tensor, + info_data: torch.Tensor, + sync_buffer_meta: Dict[str, Any], + ) -> None: + if not isinstance(data, torch.Tensor) or not isinstance( + info_data, torch.Tensor + ): + raise TypeError( + f"data and info_data must be torch.Tensors, got {type(data)}, {type(info_data)}" + ) + + self.sync_data_meta = sync_buffer_meta + self._state = self._build_state(data, info_data) + self._reconstructed: Optional[torch.Tensor] = None + self._shm: Optional[shared_memory.SharedMemory] = None + + def _build_state( + self, data: torch.Tensor, info_data: torch.Tensor + ) -> Dict[str, Any]: + try: + storage = data.untyped_storage() + handle = storage._share_cuda_() + return { + "ipc_handle": { + "handle": handle, + "shape": data.shape, + "dtype": data.dtype, + "stride": data.stride(), + "device_index": data.device.index, + "storage_offset": data.storage_offset(), + "target_shape": info_data.shape, + "target_dtype": info_data.dtype, + }, + "tensor_data": None, + } + except Exception as exc: + logger.warning( + "CudaIpcTensorTransportProxy: _share_cuda_() failed (%s); " + "falling back to direct tensor.", + exc, + ) + return {"ipc_handle": None, "tensor_data": data} + + def reconstruct_on_device(self, device_index: Optional[int] = None) -> torch.Tensor: + """Map IPC memory and copy into a new local tensor. + + This **must** be called from the *receiver* process. After the copy + the sync flag is incremented so the sender can recycle the pool chunk. 
+ """ + if self._reconstructed is not None: + return self._reconstructed + + state = self._state + if state["ipc_handle"] is not None: + h = state["ipc_handle"] + source_device = torch.device(f"cuda:{h['device_index']}") + target_device = ( + source_device + if device_index is None + else torch.device(f"cuda:{device_index}") + ) + with torch.cuda.device(source_device): + storage = torch.UntypedStorage._new_shared_cuda(*h["handle"]) + slice_tensor = torch.empty( + 0, dtype=h["dtype"], device=source_device + ).set_( + storage, + storage_offset=h["storage_offset"], + size=h["shape"], + stride=h["stride"], + ) + + result = torch.empty( + h["target_shape"], dtype=h["target_dtype"], device=target_device + ).contiguous() + result.view(torch.int8).view(-1).copy_(slice_tensor) + + # Signal sender that the chunk can be recycled + _increment_sync_flag(self.sync_data_meta) + elif state["tensor_data"] is not None: + result = state["tensor_data"] + if device_index is not None: + result = result.to(f"cuda:{device_index}", non_blocking=True) + else: + raise RuntimeError("CudaIpcTensorTransportProxy: invalid state") + + self._reconstructed = result + return result + + +# --------------------------------------------------------------------------- +# TransportProxyTensor – simple CUDA IPC via torch.Tensor subclass + pickle +# --------------------------------------------------------------------------- + + +class TransportProxyTensor(torch.Tensor): + """A ``torch.Tensor`` subclass whose pickle uses CUDA IPC handles. + + When ``transport_mode == "cuda_ipc"`` and the tensor is on CUDA, + ``__getstate__`` exports the tensor via ``storage._share_cuda_()`` instead + of serialising the raw data. ``__setstate__`` reconstructs it in the + receiving process via ``UntypedStorage._new_shared_cuda``. + + Caveat: The underlying GPU allocation is never freed until the *sender* + process exits (PyTorch limitation). Prefer ``"cuda_ipc_pool"`` mode for + long-running services to avoid GPU memory leaks. 
+ + When the tensor is on CPU or ``transport_mode == "default"``, the tensor + is serialised normally (pickle of raw data). + """ + + @staticmethod + def __new__( + cls, + data: torch.Tensor, + transport_mode: TensorTransportMode = "default", + ) -> "TransportProxyTensor": + if not isinstance(data, torch.Tensor): + raise TypeError(f"data must be a torch.Tensor, got {type(data)}") + instance = data.as_subclass(cls) + instance._transport_mode = transport_mode + return instance + + def __getstate__(self) -> Dict[str, Any]: + state: Dict[str, Any] = { + "transport_mode": self._transport_mode, + "tensor_data": None, + "ipc_extra": None, + } + if self._transport_mode == "cuda_ipc" and self.is_cuda: + try: + storage = self.untyped_storage() + handle = storage._share_cuda_() + state["ipc_extra"] = { + "handle": handle, + "shape": self.shape, + "dtype": self.dtype, + "stride": self.stride(), + "device_index": self.device.index, + "storage_offset": self.storage_offset(), + } + except Exception as exc: + logger.warning( + "TransportProxyTensor: _share_cuda_() failed (%s); falling back.", + exc, + ) + state["transport_mode"] = "default" + state["tensor_data"] = self.as_subclass(torch.Tensor) + else: + state["transport_mode"] = "default" + state["tensor_data"] = self.as_subclass(torch.Tensor) + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + self._transport_mode = state["transport_mode"] + if state["transport_mode"] == "cuda_ipc" and state["ipc_extra"] is not None: + h = state["ipc_extra"] + target = torch.device(f"cuda:{h['device_index']}") + try: + with torch.cuda.device(target): + storage = torch.UntypedStorage._new_shared_cuda(*h["handle"]) + reconstructed = torch.empty( + 0, dtype=h["dtype"], device=target + ).set_( + storage, + storage_offset=h["storage_offset"], + size=h["shape"], + stride=h["stride"], + ) + self.set_(reconstructed) + except Exception as exc: + logger.error("TransportProxyTensor: failed to open IPC handle: %s", exc) + raise + 
elif state["tensor_data"] is not None: + self.set_(state["tensor_data"]) + else: + raise RuntimeError("TransportProxyTensor: invalid state – no tensor data") + + @property + def transport_mode(self) -> TensorTransportMode: + return getattr(self, "_transport_mode", "default") + + +# --------------------------------------------------------------------------- +# Helpers: wrap / unwrap mm_inputs dicts +# --------------------------------------------------------------------------- + + +def wrap_mm_inputs_for_ipc( + mm_inputs: Optional[Dict[str, Any]], + transport_mode: TensorTransportMode, + pool: Optional["MmItemMemoryPool"] = None, +) -> Optional[Dict[str, Any]]: + """Recursively wrap CUDA tensors in *mm_inputs* for IPC transport. + + Args: + mm_inputs: Nested dict/list of tensors and other data. + transport_mode: One of ``"default"``, ``"cuda_ipc"``, ``"cuda_ipc_pool"``. + pool: Required when ``transport_mode == "cuda_ipc_pool"``. + + Returns: + A new data structure with CUDA tensors replaced by IPC proxies. + CPU tensors are left unchanged (they will be shared via ``share_memory_()`` + or normal pickling downstream). 
+ """ + if mm_inputs is None: + return None + return _wrap_recursive(mm_inputs, transport_mode, pool) + + +def _wrap_recursive( + data: Any, + transport_mode: TensorTransportMode, + pool: Optional["MmItemMemoryPool"], +) -> Any: + if isinstance(data, torch.Tensor) and data.is_cuda: + return _wrap_cuda_tensor(data, transport_mode, pool) + elif isinstance(data, dict): + return {k: _wrap_recursive(v, transport_mode, pool) for k, v in data.items()} + elif isinstance(data, (list, tuple)): + wrapped = [_wrap_recursive(item, transport_mode, pool) for item in data] + return type(data)(wrapped) + else: + return data + + +def _wrap_cuda_tensor( + tensor: torch.Tensor, + transport_mode: TensorTransportMode, + pool: Optional["MmItemMemoryPool"], +) -> Any: + if transport_mode == "cuda_ipc": + return TransportProxyTensor(tensor, transport_mode="cuda_ipc") + + if transport_mode == "cuda_ipc_pool": + if pool is None: + raise ValueError("pool must be provided for transport_mode='cuda_ipc_pool'") + sync_meta, pool_slice = pool.get_slice_with_flag(tensor) + if pool_slice is not None: + # Copy tensor bytes into the pool slice + pool_slice.copy_(tensor.view(torch.int8).view(-1), non_blocking=True) + return CudaIpcTensorTransportProxy( + data=pool_slice, + info_data=tensor, + sync_buffer_meta=sync_meta, + ) + else: + # Pool full – fall back to simple IPC (with potential memory leak) + logger.warning( + "Pool full; falling back to simple CUDA IPC (potential memory leak)" + ) + return TransportProxyTensor(tensor, transport_mode="cuda_ipc") + + # "default" – move to CPU shared memory (handled by share_memory_() downstream) + return tensor + + +def unwrap_mm_inputs_from_ipc( + mm_inputs: Optional[Dict[str, Any]], + device_index: Optional[int] = None, +) -> Optional[Dict[str, Any]]: + """Recursively reconstruct tensors from IPC proxy objects. + + Call this in the *receiver* process after getting data from the queue. 
+ + Args: + mm_inputs: Data structure possibly containing IPC proxy objects. + device_index: If not None, move reconstructed tensors to this device. + """ + if mm_inputs is None: + return None + return _unwrap_recursive(mm_inputs, device_index) + + +def _unwrap_recursive(data: Any, device_index: Optional[int]) -> Any: + if isinstance(data, CudaIpcTensorTransportProxy): + return data.reconstruct_on_device(device_index) + elif isinstance(data, TransportProxyTensor): + # Already reconstructed during unpickling; just return as plain tensor + return data.as_subclass(torch.Tensor) + elif isinstance(data, dict): + return {k: _unwrap_recursive(v, device_index) for k, v in data.items()} + elif isinstance(data, (list, tuple)): + result = [_unwrap_recursive(item, device_index) for item in data] + return type(data)(result) + else: + return data diff --git a/pymllm/orchestrator/detokenizer_process.py b/pymllm/orchestrator/detokenizer_process.py new file mode 100644 index 000000000..c2154e447 --- /dev/null +++ b/pymllm/orchestrator/detokenizer_process.py @@ -0,0 +1,194 @@ +""" +DetokenizerProcess -- subprocess that converts token IDs back to text. + +Receives ``BatchTokenIDOut``-style dicts from the SchedulerProcess, +detokenizes them, and forwards the decoded strings to the +RequestResponseProcess. +""" + +import logging +from multiprocessing.connection import Connection +from typing import Any, Dict, List, Optional + +import zmq + +from pymllm.orchestrator.ipc_utils import create_zmq_socket, setup_subprocess_logging + +logger = logging.getLogger(__name__) + + +class DetokenizerProcess: + """Runs inside a subprocess. 
Detokenizes finished outputs.""" + + def __init__( + self, + recv_from_scheduler_addr: str, + send_to_rr_addr: str, + tokenizer_cfg: Optional[Dict[str, Any]] = None, + ): + self._recv_from_scheduler_addr = recv_from_scheduler_addr + self._send_to_rr_addr = send_to_rr_addr + self._tokenizer_cfg = tokenizer_cfg or {} + + self._zmq_ctx: Optional[zmq.Context] = None + self._recv_from_scheduler: Optional[zmq.Socket] = None + self._send_to_rr: Optional[zmq.Socket] = None + + self._tokenizer = None + # Track previous decoded text per rid for incremental (delta) output + self._rid_to_prev_text: Dict[str, str] = {} + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def init_sockets(self) -> None: + self._zmq_ctx = zmq.Context() + self._recv_from_scheduler = create_zmq_socket( + self._zmq_ctx, + zmq.PULL, + self._recv_from_scheduler_addr, + bind=False, + ) + self._send_to_rr = create_zmq_socket( + self._zmq_ctx, + zmq.PUSH, + self._send_to_rr_addr, + bind=False, + ) + + def init_tokenizer(self) -> None: + """Load the tokenizer from the configured path.""" + tokenizer_path = self._tokenizer_cfg.get("tokenizer_path") + if tokenizer_path is None: + logger.warning( + "No tokenizer_path in tokenizer_cfg; detokenization disabled" + ) + return + + from transformers import AutoTokenizer + + trust_remote_code = self._tokenizer_cfg.get("trust_remote_code", False) + self._tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, + trust_remote_code=trust_remote_code, + ) + logger.info("Detokenizer loaded tokenizer from %s", tokenizer_path) + + def event_loop(self) -> None: + """Infinite loop: recv token IDs -> detokenize -> send text to RR.""" + logger.info("DetokenizerProcess event loop started") + while True: + token_id_out = self._recv_from_scheduler.recv_pyobj() + results = self._detokenize(token_id_out) + for result in results: + self._send_to_rr.send_pyobj(result) + + 
# ------------------------------------------------------------------ + # Detokenization + # ------------------------------------------------------------------ + + def _detokenize(self, token_id_out: Dict[str, Any]) -> List[Dict[str, Any]]: + """Convert token IDs to text and fan out one result per rid. + + The scheduler sends a batch dict with parallel lists keyed by + ``"rids"``, ``"output_ids"``, ``"finished_reasons"``, etc. + This method decodes each rid's output_ids and produces one result + dict per rid with keys ``"rid"`` (singular) and ``"finished"`` + (bool) as expected by ``RequestResponseProcess._recv_loop``. + """ + rids: List[str] = token_id_out.get("rids", []) + output_ids: List[int] = token_id_out.get("output_ids", []) + finished_reasons: List[Optional[str]] = token_id_out.get("finished_reasons", []) + decode_ids: List[int] = token_id_out.get("decode_ids", []) + skip_special_tokens_list: List[bool] = token_id_out.get( + "skip_special_tokens", [] + ) + prompt_tokens_list: List[int] = token_id_out.get("prompt_tokens", []) + completion_tokens_list: List[int] = token_id_out.get("completion_tokens", []) + + results: List[Dict[str, Any]] = [] + + for i, rid in enumerate(rids): + finished_reason = finished_reasons[i] if i < len(finished_reasons) else None + is_finished = finished_reason is not None + skip_special = ( + skip_special_tokens_list[i] + if i < len(skip_special_tokens_list) + else True + ) + prompt_tokens = prompt_tokens_list[i] if i < len(prompt_tokens_list) else 0 + completion_tokens = ( + completion_tokens_list[i] if i < len(completion_tokens_list) else 0 + ) + + # Decode text from output_ids + if self._tokenizer is not None: + text = self._tokenizer.decode( + output_ids, + skip_special_tokens=skip_special, + ) + else: + text = "" + + # Compute incremental delta by diffing against previous text + prev_text = self._rid_to_prev_text.get(rid, "") + delta_text = text[len(prev_text):] + self._rid_to_prev_text[rid] = text + + # Clean up tracking when 
request finishes + if is_finished: + self._rid_to_prev_text.pop(rid, None) + + result: Dict[str, Any] = { + "rid": rid, + "text": text, + "delta": delta_text, + "output_token_ids": list(output_ids), + "finished": is_finished, + "finished_reason": finished_reason, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + } + results.append(result) + + return results + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + def shutdown(self) -> None: + if self._recv_from_scheduler is not None: + self._recv_from_scheduler.close() + if self._send_to_rr is not None: + self._send_to_rr.close() + if self._zmq_ctx is not None: + self._zmq_ctx.term() + + +def run_detokenizer_process( + recv_from_scheduler_addr: str, + send_to_rr_addr: str, + pipe_writer: Connection, + tokenizer_cfg: Optional[Dict[str, Any]] = None, +) -> None: + """Entry point for ``torch.multiprocessing.Process(target=...)``.""" + setup_subprocess_logging((tokenizer_cfg or {}).get("log_level", "info")) + proc = DetokenizerProcess( + recv_from_scheduler_addr, + send_to_rr_addr, + tokenizer_cfg=tokenizer_cfg, + ) + proc.init_sockets() + proc.init_tokenizer() + + pipe_writer.send({"status": "ready", "process": "detokenizer"}) + pipe_writer.close() + + try: + proc.event_loop() + except KeyboardInterrupt: + pass + finally: + proc.shutdown() diff --git a/pymllm/orchestrator/group_coordinator.py b/pymllm/orchestrator/group_coordinator.py new file mode 100644 index 000000000..2fec30784 --- /dev/null +++ b/pymllm/orchestrator/group_coordinator.py @@ -0,0 +1,104 @@ +"""GroupCoordinator for distributed communication.""" + +from typing import List +import torch +import torch.distributed as dist + + +class GroupCoordinator: + """Manages a group of processes for distributed communication. + + Lightweight wrapper around torch.distributed.ProcessGroup. 
+ + Args: + ranks: List of global ranks in this group + local_rank: Local rank for device assignment + backend: Backend to use (nccl, gloo, etc.) + """ + + def __init__( + self, + ranks: List[int], + local_rank: int, + backend: str = "nccl", + ): + self.ranks = ranks + self.local_rank = local_rank + self.backend = backend + self.world_size = len(ranks) + + # Get rank in this specific group + self.rank_in_group = ranks.index(dist.get_rank()) if dist.is_initialized() else 0 + + # Create process group + if dist.is_initialized() and self.world_size > 1: + self.device_group = dist.new_group(ranks, backend=backend) + else: + self.device_group = None + + def all_reduce(self, tensor: torch.Tensor) -> torch.Tensor: + """All-reduce across the group.""" + if self.device_group is not None: + dist.all_reduce(tensor, group=self.device_group) + return tensor + + def all_gather(self, tensor: torch.Tensor, dim: int = 0) -> torch.Tensor: + """All-gather across the group.""" + if self.device_group is None: + return tensor + + world_size = self.world_size + if dim == 0: + shape = list(tensor.shape) + shape[0] = shape[0] * world_size + output = torch.empty(shape, dtype=tensor.dtype, device=tensor.device) + dist.all_gather_into_tensor(output, tensor, group=self.device_group) + return output + else: + # For non-dim-0 gathers, use tensor list + tensor_list = [ + torch.empty_like(tensor) for _ in range(world_size) + ] + dist.all_gather(tensor_list, tensor, group=self.device_group) + return torch.cat(tensor_list, dim=dim) + + def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """Broadcast from source rank to all. + + Args: + tensor: Tensor to broadcast. + src: Source rank relative to this group (0 <= src < world_size). 
def divide(numerator: int, denominator: int) -> int:
    """Integer-divide *numerator* by *denominator*, requiring exact divisibility.

    Raises:
        ValueError: If ``numerator`` is not a multiple of ``denominator``.
    """
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently allow truncating division.
    if numerator % denominator != 0:
        raise ValueError(f"{numerator} is not divisible by {denominator}")
    return numerator // denominator


def split_tensor_along_dim(
    tensor: torch.Tensor,
    dim: int,
    world_size: int,
    rank: int,
) -> torch.Tensor:
    """Return *rank*'s contiguous chunk of *tensor* along dimension *dim*.

    Used for tensor parallelism: dimension ``dim`` is split into
    ``world_size`` equal chunks and the ``rank``-th chunk is returned as a
    view (no copy).

    Raises:
        ValueError: If the size of ``dim`` is not divisible by ``world_size``.
    """
    dim_size = tensor.size(dim)
    # Explicit raise instead of `assert` (stripped under `-O`).
    if dim_size % world_size != 0:
        raise ValueError(
            f"Dimension {dim} ({dim_size}) not divisible by world_size {world_size}"
        )

    chunk_size = dim_size // world_size
    start = rank * chunk_size

    # Build a full-slice index tuple and narrow only the target dimension.
    slices = [slice(None)] * tensor.ndim
    slices[dim] = slice(start, start + chunk_size)
    return tensor[tuple(slices)]
+ """ + _ensure_ipc_dir() + suffix = f"_{unique_id}" if unique_id else "" + return f"ipc://{_IPC_DIR}/pymllm_{name}{suffix}" + + +def create_zmq_socket( + ctx: zmq.Context, + socket_type: int, + address: str, + bind: bool, +) -> zmq.Socket: + """Create a ZMQ socket, bind or connect it, and return it. + + Parameters + ---------- + ctx + A ``zmq.Context`` shared within the process. + socket_type + One of ``zmq.PUSH``, ``zmq.PULL``, ``zmq.PAIR``, etc. + address + The ``ipc://`` address string. + bind + If ``True`` the socket calls ``bind``; otherwise ``connect``. + """ + sock = ctx.socket(socket_type) + sock.setsockopt(zmq.LINGER, 0) + if bind: + sock.bind(address) + else: + sock.connect(address) + return sock + + +def close_zmq_socket(sock: zmq.Socket) -> None: + """Close a ZMQ socket, ignoring errors.""" + try: + sock.close() + except zmq.ZMQError: + pass + + +def setup_subprocess_logging(log_level: str = "info") -> None: + """Configure logging for a spawned subprocess. + + When Python spawns a subprocess (``mp.set_start_method('spawn')``), the + child starts with a blank logging configuration. Call this function at the + very beginning of every subprocess entry point so that log records are + emitted at the correct level. + + Parameters + ---------- + log_level + Case-insensitive level name, e.g. ``"debug"``, ``"info"``, ``"warning"``. + """ + level = getattr(logging, log_level.upper(), logging.INFO) + logging.basicConfig( + level=level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + logging.getLogger("pymllm").setLevel(level) diff --git a/pymllm/orchestrator/model_runner_process.py b/pymllm/orchestrator/model_runner_process.py new file mode 100644 index 000000000..d850dd53e --- /dev/null +++ b/pymllm/orchestrator/model_runner_process.py @@ -0,0 +1,968 @@ +""" +ModelRunnerProcess -- GPU-owning component that executes model forward passes. + +Instantiated **in-process** by :class:`SchedulerProcess` (sglang-style +architecture). 
The scheduler calls :meth:`_forward_batch` directly — +no inter-process communication is involved. + +This component owns the GPU: it holds a :class:`ModelRunner` with model +weights, KV-cache memory pools, and the attention backend. It also owns +the :class:`RadixCache` for prefix-aware KV reuse. + +RadixCache lifecycle +-------------------- +1. **match_prefix** — called during ``_allocate_extend`` before KV allocation. +2. **inc_lock_ref** — locks matched radix-tree nodes to prevent eviction. +3. **insert (prefill)** — inserts prompt KV indices after prefill. +4. **insert (completion)** — re-inserts the full sequence when a request finishes. +5. **dec_lock_ref** — unlocks radix-tree nodes when a request is freed. +6. **evict** — called when KV allocation fails to free stale cache entries. +""" + +import logging +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from pymllm.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode + +logger = logging.getLogger(__name__) + +# Fraction of KV pool to try evicting when allocation fails. +_EVICT_FRACTION = 0.10 +# Maximum number of eviction retries before giving up. 
_MAX_EVICT_RETRIES = 3


class ModelRunnerProcess:
    """GPU-owning component created in-process by SchedulerProcess."""

    def __init__(
        self,
        gpu_id: int = 0,
        server_config: Optional[Any] = None,
        model_config: Optional[Any] = None,
    ):
        self._gpu_id = gpu_id
        self._server_config = server_config
        self._model_config = model_config

        # The ModelRunner instance (created in init_model)
        self._runner = None
        self._is_hybrid: bool = False

        # RadixCache instance (created in init_model, after memory pools)
        self._radix_cache: Optional[RadixCache] = None

        # GPU resource tracking: maps rid -> req_pool_idx (slot in ReqToTokenPool)
        self._rid_to_req_pool_idx: Dict[str, int] = {}
        # Maps rid -> kv_indices tensor (all KV-cache token indices for this request)
        self._rid_to_kv_indices: Dict[str, torch.Tensor] = {}
        # Maps rid -> input_ids used for prefill (needed for radix cache insert)
        self._rid_to_input_ids: Dict[str, List[int]] = {}
        # Maps rid -> list of generated (decode) token ids, appended each step.
        # Used to build the full sequence for radix cache insert at completion.
        self._rid_to_output_ids: Dict[str, List[int]] = {}
        # Maps rid -> cache_protected_len: the length of the prefix that has
        # already been inserted into the radix cache. When insert() returns
        # prefix_len > cache_protected_len, the KV indices in the overlap
        # range [cache_protected_len, prefix_len) are duplicates that must
        # be freed from the allocator (the tree already holds cloned copies).
        self._rid_to_cache_protected_len: Dict[str, int] = {}
        # Maps rid -> (last_node, swa_boundary_id) for radix cache lock tracking
        self._rid_to_radix_lock: Dict[str, Tuple[TreeNode, Optional[int]]] = {}
        # Maps rid -> mrope_position_delta (M-RoPE positional offset per request)
        # Populated during prefill; used to offset decode-step positions for
        # multimodal models (Qwen3-VL) that consume more position indices than
        # tokens due to 3-D image grid positions.
        self._rid_to_mrope_delta: Dict[str, int] = {}

        # GDN prefix cache state tracking (hybrid models only):
        # Maps rid -> GDN track slot index in GDNPool (for snapshotting state)
        self._rid_to_gdn_track_slot: Dict[str, int] = {}
        # Maps radix tree node id -> GDN track slot index
        self._node_id_to_gdn_track_slot: Dict[int, int] = {}

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def init_model(self) -> None:
        """Create and initialise the ModelRunner and RadixCache.

        Must run inside the subprocess (after spawn) since it does CUDA init.
        """
        # Imported lazily so module import stays cheap and CUDA-free.
        from pymllm.executor.model_runner import ModelRunner

        logger.info(
            "ModelRunnerProcess: initialising ModelRunner on GPU %d",
            self._gpu_id,
        )
        self._runner = ModelRunner(
            server_config=self._server_config,
            model_config=self._model_config,
            gpu_id=self._gpu_id,
        )
        self._runner.initialize()

        # Initialise RadixCache after memory pools are ready.
        disable_cache = getattr(self._server_config, "disable_radix_cache", False)
        # Hybrid = model has GDN (linear-attention) layers alongside full attention.
        self._is_hybrid = self._runner.num_gdn_layers > 0
        if self._is_hybrid and not disable_cache:
            logger.info(
                "ModelRunnerProcess: prefix caching ENABLED with GDN state "
                "tracking (%d GDN layers)",
                self._runner.num_gdn_layers,
            )
        sliding_window = self._runner.sliding_window_size
        page_size = getattr(self._server_config, "radix_cache_page_size", 1)
        # For hybrid models, register an eviction callback so that evicted
        # radix nodes free their associated GDN track slots.
+ evict_cb = self._on_radix_node_evict if self._is_hybrid else None + self._radix_cache = RadixCache( + page_size=page_size, + sliding_window_size=sliding_window, + disable=disable_cache, + token_to_kv_pool_allocator=self._runner.token_to_kv_pool_allocator, + on_node_evict=evict_cb, + ) + logger.info( + "ModelRunnerProcess: RadixCache initialized " + "(disable=%s, sliding_window=%s)", + disable_cache, + sliding_window, + ) + logger.info("ModelRunnerProcess: ModelRunner ready") + + # ------------------------------------------------------------------ + # Forward pass + # ------------------------------------------------------------------ + + def _forward_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: + """Run the model forward pass and sampling for *batch*. + + *batch* is a dict produced by ``ScheduleBatch.to_batch_dict()`` + containing ``"forward_mode"``, ``"input_ids"``, ``"seq_lens"``, + ``"req_pool_indices"``, ``"requests"`` (metadata list), etc. + + Implements 6 phases: + 1. Cleanup: free GPU resources for rids no longer in the batch + 2. Prefix matching + KV allocation + 3. Build GPU tensors + 4. Forward + sample + 5. Radix cache insert (extend only) + 6. Build result dict + """ + runner = self._runner + forward_mode = batch.get("forward_mode", "decode") + batch_size = batch.get("batch_size", 0) + requests_meta: List[Dict[str, Any]] = batch.get("requests", []) + + if batch_size == 0: + return {"batch_id": batch.get("batch_id"), "outputs": []} + + device = runner.device + + # Collect current batch rids + current_rids: Set[str] = {m["rid"] for m in requests_meta} + + # ============================================================== + # Phase 2: Prefix matching + KV allocation + # ============================================================== + # For extend batches, match_prefix is done inside _allocate_extend + # which may update extend_prefix_lens and extend_seq_lens. 
        if forward_mode == "extend":
            out_cache_loc, actual_prefix_lens, actual_extend_lens = (
                self._allocate_extend(batch, requests_meta)
            )
        else:
            out_cache_loc = self._allocate_decode(batch, requests_meta)
            actual_prefix_lens = None
            actual_extend_lens = None

        # ==============================================================
        # Phase 3: Build GPU tensors
        # ==============================================================
        if forward_mode == "extend" and actual_prefix_lens is not None:
            # Rebuild input_ids and seq_lens using actual prefix matches.
            # The scheduler sent tokens assuming prefix_len=0; we need to
            # trim the input_ids to skip the prefix-matched tokens.
            (
                input_ids_tensor,
                seq_lens_tensor,
                extend_seq_lens_t,
                extend_prefix_lens_t,
            ) = self._rebuild_extend_tensors(
                batch, requests_meta, actual_prefix_lens, actual_extend_lens, device
            )
        else:
            input_ids_list: List[int] = batch["input_ids"]
            seq_lens_list: List[int] = batch["seq_lens"]
            input_ids_tensor = torch.tensor(
                input_ids_list, dtype=torch.int32, device=device
            )
            seq_lens_tensor = torch.tensor(
                seq_lens_list, dtype=torch.int32, device=device
            )
            extend_seq_lens_t = None
            extend_prefix_lens_t = None

        # Build req_pool_indices from our own tracking (NOT from scheduler)
        req_pool_indices = torch.tensor(
            [self._rid_to_req_pool_idx[m["rid"]] for m in requests_meta],
            dtype=torch.int64,
            device=device,
        )

        out_cache_loc = out_cache_loc.to(torch.int64)

        # ==============================================================
        # Phase 4: Forward + sample
        # ==============================================================
        # Extract per-request sampling params
        # (defaults: temperature=1.0, top_p=1.0, top_k=-1 — presumably -1
        # means "top-k disabled"; confirm against the sampler implementation)
        temperatures = []
        top_ps = []
        top_ks = []
        for m in requests_meta:
            sp = m.get("sampling_params") or {}
            temperatures.append(sp.get("temperature", 1.0))
            top_ps.append(sp.get("top_p", 1.0))
            top_ks.append(sp.get("top_k", -1))

        temps_tensor = torch.tensor(temperatures, dtype=torch.float32, device=device)
        top_ps_tensor = torch.tensor(top_ps, dtype=torch.float32, device=device)
        top_ks_tensor = torch.tensor(top_ks, dtype=torch.int32, device=device)

        if forward_mode == "extend":
            # Fall back to the scheduler-provided lengths when no prefix
            # matching rebuilt them above.
            if extend_seq_lens_t is None:
                extend_seq_lens_list: List[int] = batch["extend_seq_lens"]
                extend_prefix_lens_list: List[int] = batch["extend_prefix_lens"]
                extend_seq_lens_t = torch.tensor(
                    extend_seq_lens_list, dtype=torch.int32, device=device
                )
                extend_prefix_lens_t = torch.tensor(
                    extend_prefix_lens_list, dtype=torch.int32, device=device
                )

            fb = runner.prepare_forward_batch_extend(
                input_ids=input_ids_tensor,
                req_pool_indices=req_pool_indices,
                seq_lens=seq_lens_tensor,
                extend_seq_lens=extend_seq_lens_t,
                extend_prefix_lens=extend_prefix_lens_t,
                out_cache_loc=out_cache_loc,
            )

            # Attach multimodal vision inputs to ForwardBatch so the
            # model's vision encoder can process images during prefill.
            # The tokenizer wraps processor output under "image_inputs";
            # fall back to top-level keys for direct dicts.
            pixel_values_list = []
            image_grid_thw_list = []
            for m in requests_meta:
                mm = m.get("mm_inputs")
                if mm is None:
                    continue
                # AutoProcessor output is nested under "image_inputs"
                src = mm.get("image_inputs") if "image_inputs" in mm else mm
                if src is None:
                    continue
                # src may be a dict-like or an object with attributes;
                # support both access styles.
                pv = src.get("pixel_values") if hasattr(src, "get") else getattr(src, "pixel_values", None)
                thw = src.get("image_grid_thw") if hasattr(src, "get") else getattr(src, "image_grid_thw", None)
                if pv is not None:
                    if not isinstance(pv, torch.Tensor):
                        pv = torch.as_tensor(pv)
                    pixel_values_list.append(pv.to(device=device))
                if thw is not None:
                    if not isinstance(thw, torch.Tensor):
                        thw = torch.as_tensor(thw)
                    image_grid_thw_list.append(thw.to(device=device))
            # Concatenate per-request image tensors along the batch axis
            # (assumes dim 0 is the patch/image axis — TODO confirm).
            if pixel_values_list:
                fb.pixel_values = torch.cat(pixel_values_list, dim=0)
            if image_grid_thw_list:
                fb.image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
        else:
            # Build mrope_position_deltas tensor for decode batches.
            mrope_deltas = [
                self._rid_to_mrope_delta.get(m["rid"], 0) for m in requests_meta
            ]
            mrope_deltas_tensor = torch.tensor(
                mrope_deltas, dtype=torch.int64, device=device
            )

            fb = runner.prepare_forward_batch_decode(
                input_ids=input_ids_tensor,
                req_pool_indices=req_pool_indices,
                seq_lens=seq_lens_tensor,
                out_cache_loc=out_cache_loc,
                mrope_position_deltas=mrope_deltas_tensor,
            )

        logits_output = runner.forward(fb)

        # Persist M-RoPE position deltas for multimodal models (Qwen3-VL).
        # The model sets mrope_position_deltas on the ForwardBatch during
        # prefill; we store them here so decode steps can retrieve them.
        if (
            forward_mode == "extend"
            and getattr(fb, "mrope_position_deltas", None) is not None
        ):
            deltas_cpu = fb.mrope_position_deltas.cpu().tolist()
            for idx, m in enumerate(requests_meta):
                self._rid_to_mrope_delta[m["rid"]] = int(deltas_cpu[idx])

        next_token_ids = runner.sample(
            logits_output,
            fb,
            temperatures=temps_tensor,
            top_ps=top_ps_tensor,
            top_ks=top_ks_tensor,
        )

        # ==============================================================
        # Phase 4.5: Snapshot GDN state after extend (hybrid models)
        # ==============================================================
        if forward_mode == "extend" and self._is_hybrid:
            self._track_gdn_state_after_extend(requests_meta)

        # ==============================================================
        # Phase 5: Radix cache insert (extend only)
        # ==============================================================
        if forward_mode == "extend" and self._radix_cache is not None:
            self._insert_into_radix_cache(requests_meta)

        # ==============================================================
        # Phase 6: Build result & track output tokens
        # ==============================================================
        next_ids_cpu = next_token_ids.cpu().tolist()
        outputs: List[Dict[str, Any]] = []
        for i, m in enumerate(requests_meta):
            rid = m["rid"]
            # Defensive default 0 if the sampler returned fewer ids than
            # requests — TODO confirm this case can actually occur.
            token_id = next_ids_cpu[i] if i < len(next_ids_cpu) else 0
            # Track output tokens for radix cache insert at completion
            out_ids = self._rid_to_output_ids.get(rid)
            if out_ids is not None:
                out_ids.append(token_id)

            out: Dict[str, Any] = {
                "rid": rid,
                "output_token_ids": [token_id],
            }
            # Report actual prefix_len back to the scheduler so it can
            # update its token budget tracking accurately.
            if actual_prefix_lens is not None:
                out["prefix_len"] = actual_prefix_lens[i]
            outputs.append(out)

        return {
            "batch_id": batch.get("batch_id"),
            "outputs": outputs,
        }

    # ------------------------------------------------------------------
    # Tensor rebuild for prefix-matched extend
    # ------------------------------------------------------------------

    def _rebuild_extend_tensors(
        self,
        batch: Dict[str, Any],
        requests_meta: List[Dict[str, Any]],
        actual_prefix_lens: List[int],
        actual_extend_lens: List[int],
        device: str,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Rebuild input_ids and related tensors after prefix matching.

        The scheduler sent input_ids assuming no prefix cache hit. After
        radix cache matching, we know the actual prefix lengths and must
        trim the input_ids accordingly.

        Returns (input_ids, seq_lens, extend_seq_lens, extend_prefix_lens)
        as GPU tensors.
        """
        # Reconstruct trimmed input_ids: for each request, take only the
        # tokens beyond the matched prefix.
        new_input_ids: List[int] = []
        seq_lens_list: List[int] = batch["seq_lens"]

        for i, m in enumerate(requests_meta):
            full_input_ids = m.get("input_ids", [])
            prefix_len = actual_prefix_lens[i]
            # Only send tokens after the prefix
            new_input_ids.extend(full_input_ids[prefix_len:])

        input_ids = torch.tensor(new_input_ids, dtype=torch.int32, device=device)
        seq_lens = torch.tensor(seq_lens_list, dtype=torch.int32, device=device)
        extend_seq_lens = torch.tensor(
            actual_extend_lens, dtype=torch.int32, device=device
        )
        extend_prefix_lens = torch.tensor(
            actual_prefix_lens, dtype=torch.int32, device=device
        )
        return input_ids, seq_lens, extend_seq_lens, extend_prefix_lens

    # ------------------------------------------------------------------
    # Radix cache insert
    # ------------------------------------------------------------------

    def _insert_into_radix_cache(self, requests_meta: List[Dict[str, Any]]) -> None:
        """Insert prefill KV indices into the radix cache for future reuse.

        Mirrors sglang's ``cache_unfinished_req`` pattern:

        1. **Insert** the request's token → KV index mapping into the tree.
        2. **Free duplicates** — indices in ``[cache_protected_len, new_prefix_len)``
           are now owned by the tree; the request's copies are redundant.
        3. **Re-match + write-back** — fetch the tree's *own* indices via
           ``match_prefix`` and write them into ``req_to_token_pool``,
           replacing the just-freed entries. Without this step the pool
           still points at freed slots → use-after-free during decode.
        4. **Update** ``cache_protected_len`` and radix lock.
        """
        cache = self._radix_cache
        if cache is None or cache.disable:
            return

        runner = self._runner
        gdn_pool = getattr(runner, "gdn_pool", None)

        for m in requests_meta:
            rid = m["rid"]
            input_ids = self._rid_to_input_ids.get(rid)
            if input_ids is None:
                # No prefill record for this rid (e.g. decode-only) — skip.
                continue

            slot = self._rid_to_req_pool_idx.get(rid)
            if slot is None:
                continue

            seq_len = len(input_ids)
            kv_indices = runner.req_to_token_pool.req_to_token[slot, :seq_len].to(
                torch.int64
            )

            key = RadixKey(input_ids)
            result = cache.insert(key, kv_indices)
            new_prefix_len = result.prefix_len

            # --- Step 2: free duplicates ---
            cache_protected_len = self._rid_to_cache_protected_len.get(rid, 0)
            if new_prefix_len > cache_protected_len:
                dup_indices = kv_indices[cache_protected_len:new_prefix_len]
                if dup_indices.numel() > 0:
                    runner.token_to_kv_pool_allocator.free(dup_indices)

            # --- Step 3: re-match + write-back ---
            # The tree now owns indices for [0, new_prefix_len). Fetch them
            # and patch req_to_token_pool so the request reads the tree's
            # (still-live) indices instead of the freed ones.
            rematch = cache.match_prefix(key)
            new_indices = rematch.indices
            if len(new_indices) > cache_protected_len:
                runner.req_to_token_pool.write(
                    (slot, slice(cache_protected_len, len(new_indices))),
                    new_indices[cache_protected_len:].to(torch.int32),
                )

            # --- Step 4: update tracking ---
            self._rid_to_cache_protected_len[rid] = len(new_indices)

            # Update radix lock to cover the new (potentially deeper) node.
            old_lock = self._rid_to_radix_lock.pop(rid, None)
            if old_lock is not None:
                old_node, old_swa = old_lock
                cache.dec_lock_ref(old_node, old_swa)
            new_last_node = rematch.last_node
            if new_last_node is not None and len(new_indices) > 0:
                swa_id = cache.inc_lock_ref(new_last_node)
                self._rid_to_radix_lock[rid] = (new_last_node, swa_id)

            # --- GDN track slot association (hybrid models) ---
            if gdn_pool is not None and result.last_node is not None:
                track_slot = self._rid_to_gdn_track_slot.get(rid)
                if track_slot is not None:
                    node_id = result.last_node.id
                    old_ts = self._node_id_to_gdn_track_slot.get(node_id)
                    if old_ts is None:
                        # First snapshot for this node: take ownership.
                        self._node_id_to_gdn_track_slot[node_id] = track_slot
                    else:
                        # Node already has a snapshot: ours is redundant.
                        gdn_pool.free_track_slot(track_slot)
                    self._rid_to_gdn_track_slot.pop(rid, None)

    # ------------------------------------------------------------------
    # KV allocation helpers
    # ------------------------------------------------------------------

    def _allocate_extend(
        self, batch: Dict[str, Any], requests_meta: List[Dict[str, Any]]
    ) -> Tuple[torch.Tensor, List[int], List[int]]:
        """Allocate req pool slots and KV tokens for an extend (prefill) batch.

        Performs radix cache prefix matching before allocation:
        1. For each request, call ``match_prefix`` to find cached KV indices.
        2. Write cached indices into ``ReqToTokenPool``.
        3. Only allocate new KV tokens for the non-cached suffix.
        4. Lock matched radix nodes to prevent eviction.

        Returns ``(out_cache_loc, actual_prefix_lens, actual_extend_lens)``.
        ``out_cache_loc`` has shape ``[total_new_tokens]``.
        """
        runner = self._runner
        cache = self._radix_cache
        batch_size = batch["batch_size"]
        seq_lens: List[int] = batch["seq_lens"]

        # --- Step 1: Radix cache prefix matching ---
        actual_prefix_lens: List[int] = []
        actual_extend_lens: List[int] = []
        matched_nodes: List[Optional[TreeNode]] = []
        # Cache the match results so we don't call match_prefix twice
        cached_indices_list: List[Optional[torch.Tensor]] = []
        gdn_pool = getattr(runner, "gdn_pool", None)

        for i, m in enumerate(requests_meta):
            full_input_ids: List[int] = m.get("input_ids", [])
            full_seq_len = seq_lens[i]

            # Store input_ids for later radix cache insert
            self._rid_to_input_ids[m["rid"]] = full_input_ids

            if cache is not None and not cache.disable and len(full_input_ids) > 0:
                key = RadixKey(full_input_ids)
                match_result = cache.match_prefix(key)
                prefix_len = match_result.prefix_len
                last_node = match_result.last_node
                cached_indices = match_result.indices
            else:
                prefix_len = 0
                last_node = None
                cached_indices = None

            # Hybrid model guard: only use a KV cache hit if the matched
            # node has a GDN state snapshot. Without it, the full-attention
            # layers would use cached KV while GDN layers start from zero,
            # causing an attention/GDN state mismatch. Discard the hit so
            # the entire prompt is processed from scratch.
            if (
                gdn_pool is not None
                and prefix_len > 0
                and last_node is not None
                and self._node_id_to_gdn_track_slot.get(last_node.id) is None
            ):
                logger.debug(
                    "Discarding radix cache hit for rid=%s: no GDN state "
                    "for matched node (prefix_len=%d)",
                    m["rid"], prefix_len,
                )
                prefix_len = 0
                last_node = None
                cached_indices = None

            # Ensure at least 1 token is extended (not fully cached).
            # A full cache hit (prefix_len == full_seq_len) would produce a
            # 0-length input tensor that crashes CUDA kernels. Back off by 1
            # so the model always sees the last token.
            if prefix_len >= full_seq_len:
                prefix_len = full_seq_len - 1
                if cached_indices is not None:
                    cached_indices = cached_indices[:prefix_len]

            extend_len = full_seq_len - prefix_len
            actual_prefix_lens.append(prefix_len)
            actual_extend_lens.append(extend_len)
            matched_nodes.append(last_node)
            cached_indices_list.append(cached_indices)

            if prefix_len > 0:
                logger.info(
                    "Radix cache hit for rid=%s: %d/%d tokens reused (%.1f%%)",
                    m["rid"],
                    prefix_len,
                    full_seq_len,
                    100.0 * prefix_len / full_seq_len,
                )

        total_new_tokens = sum(actual_extend_lens)

        # --- Step 2: Allocate req pool slots ---
        slots = runner.req_to_token_pool.alloc(batch_size)
        if slots is None:
            raise RuntimeError("Failed to allocate req pool slots for extend batch")

        # --- Step 3: Allocate KV tokens (with eviction retry) ---
        out_cache_loc = self._alloc_kv_with_eviction(total_new_tokens)
        if out_cache_loc is None:
            # Roll back the req pool slots before failing.
            for s in slots:
                runner.req_to_token_pool.free(s)
            raise RuntimeError(
                f"Failed to allocate {total_new_tokens} KV tokens for extend batch "
                f"(even after eviction)"
            )

        # --- Step 4: Write indices into req_to_token_pool ---
        offset = 0
        for i, m in enumerate(requests_meta):
            rid = m["rid"]
            slot = slots[i]
            prefix_len = actual_prefix_lens[i]
            extend_len = actual_extend_lens[i]
            full_seq_len = seq_lens[i]

            # Write cached prefix indices (from the match result we saved)
            cached_indices = cached_indices_list[i]
            if cached_indices is not None and prefix_len > 0:
                runner.req_to_token_pool.write(
                    (slot, slice(0, prefix_len)),
                    cached_indices[:prefix_len].to(torch.int32),
                )

            # Write new KV indices for the suffix
            kv_indices = out_cache_loc[offset : offset + extend_len]
            runner.req_to_token_pool.write(
                (slot, slice(prefix_len, full_seq_len)), kv_indices
            )

            self._rid_to_req_pool_idx[rid] = slot
            self._rid_to_kv_indices[rid] = kv_indices.clone()
            self._rid_to_output_ids[rid] = []
            # The prefix portion is already protected in the radix cache
            # (from a previous request's insert). We start with this as
            # cache_protected_len so that subsequent insert() calls know
            # which range is already covered.
            self._rid_to_cache_protected_len[rid] = actual_prefix_lens[i]
            offset += extend_len

        # GDN state management: restore from track slot on cache hit, or reset
        if gdn_pool is not None:
            for i, m in enumerate(requests_meta):
                rid = m["rid"]
                working_slot = slots[i]
                prefix_len = actual_prefix_lens[i]
                node = matched_nodes[i]

                if prefix_len > 0 and node is not None:
                    # Cache hit — try to restore GDN state from the track slot
                    # associated with the matched radix node.
                    track_slot = self._node_id_to_gdn_track_slot.get(node.id)
                    if track_slot is not None:
                        gdn_pool.copy_states(track_slot, working_slot)
                        logger.debug(
                            "GDN state restored for rid=%s from track_slot=%d "
                            "(prefix_len=%d)",
                            rid, track_slot, prefix_len,
                        )
                    else:
                        # Cache hit but no GDN snapshot — reset to zero.
                        # This can happen if the track slot was evicted.
                        idx = torch.tensor(
                            [working_slot], dtype=torch.int64, device=runner.device
                        )
                        gdn_pool.reset_states(idx)
                        logger.debug(
                            "GDN state reset for rid=%s (cache hit but no "
                            "track slot, prefix_len=%d)",
                            rid, prefix_len,
                        )
                else:
                    # No cache hit — fresh request, zero-init
                    idx = torch.tensor(
                        [working_slot], dtype=torch.int64, device=runner.device
                    )
                    gdn_pool.reset_states(idx)

                # Allocate a track slot only when the radix cache is enabled;
                # track slots are freed via the eviction callback so they must
                # be associated with a node, which only happens when cache is on.
                if cache is not None and not cache.disable:
                    ts = gdn_pool.alloc_track_slot()
                    if ts is not None:
                        self._rid_to_gdn_track_slot[rid] = ts

        # --- Step 5: Lock matched radix nodes ---
        if cache is not None and not cache.disable:
            for i, m in enumerate(requests_meta):
                node = matched_nodes[i]
                if node is not None and actual_prefix_lens[i] > 0:
                    swa_boundary_id = cache.inc_lock_ref(node)
                    self._rid_to_radix_lock[m["rid"]] = (node, swa_boundary_id)

        return out_cache_loc, actual_prefix_lens, actual_extend_lens

    def _alloc_kv_with_eviction(self, num_tokens: int) -> Optional[torch.Tensor]:
        """Try to allocate KV tokens, evicting from radix cache if needed."""
        runner = self._runner
        cache = self._radix_cache

        if num_tokens == 0:
            # Degenerate request: return an empty index tensor.
            return torch.empty(0, dtype=torch.int32, device=runner.device)

        # First attempt: direct allocation
        result = runner.token_to_kv_pool_allocator.alloc(num_tokens)
        if result is not None:
            return result

        # Eviction loop: try evicting from radix cache to free space
        if cache is None or cache.disable:
            return None

        for attempt in range(_MAX_EVICT_RETRIES):
            evictable = cache.evictable_size()
            if evictable == 0:
                logger.warning(
                    "KV allocation failed: need %d tokens, no evictable cache entries",
                    num_tokens,
                )
                return None

            # Evict a fraction of the cache (at least what we need)
            evict_target = max(
                num_tokens,
                int(runner.token_to_kv_pool_allocator.size * _EVICT_FRACTION),
            )
            evict_result = cache.evict(evict_target)
            logger.info(
                "Radix cache eviction attempt %d: evicted %d tokens (target=%d)",
                attempt + 1,
                evict_result.full_evicted,
                evict_target,
            )

            # Retry allocation
            result = runner.token_to_kv_pool_allocator.alloc(num_tokens)
            if result is not None:
                return result

        return None

    def _allocate_decode(
        self, batch: Dict[str, Any], requests_meta: List[Dict[str, Any]]
    ) -> torch.Tensor:
        """Allocate 1 KV token per request for a decode step.

        Returns ``out_cache_loc`` tensor of shape ``[batch_size]``.
        """
        runner = self._runner
        batch_size = batch["batch_size"]
        seq_lens: List[int] = batch["seq_lens"]

        # Allocate 1 new KV token per request (with eviction retry)
        out_cache_loc = self._alloc_kv_with_eviction(batch_size)
        if out_cache_loc is None:
            raise RuntimeError(
                f"Failed to allocate {batch_size} KV tokens for decode batch"
            )

        # Write the new KV token index into each request's mapping
        for i, m in enumerate(requests_meta):
            rid = m["rid"]
            slot = self._rid_to_req_pool_idx.get(rid)
            if slot is None:
                logger.warning("Decode step for unknown rid=%s, skipping KV write", rid)
                continue

            cur_seq_len = seq_lens[i]
            kv_new = out_cache_loc[i : i + 1]
            # The scheduler increments req.seq_len by 1 after every step, so
            # seq_lens[i] == (number of tokens in the KV cache INCLUDING the
            # token being decoded now). The new token's slot must therefore be
            # written at index seq_lens[i] - 1, matching the position used by
            # prepare_forward_batch_decode (positions = seq_lens - 1) and the
            # window FlashInfer reads (req_to_token_pool[slot, 0:seq_lens[i]]).
            write_pos = cur_seq_len - 1
            runner.req_to_token_pool.write(
                (slot, slice(write_pos, write_pos + 1)), kv_new
            )

            # Append to tracked kv_indices
            prev = self._rid_to_kv_indices.get(rid)
            if prev is not None:
                self._rid_to_kv_indices[rid] = torch.cat([prev, kv_new])
            else:
                self._rid_to_kv_indices[rid] = kv_new.clone()

        return out_cache_loc

    # ------------------------------------------------------------------
    # Resource cleanup
    # ------------------------------------------------------------------

    def _free_rid_resources(self, rid: str) -> None:
        """Free GPU resources (req pool slot + KV indices) for a finished rid.

        KV index ownership model (when radix cache is enabled):

        ``req_to_token_pool[slot]`` contains three regions after
        ``insert()`` returns ``new_prefix_len``::

            [0, cache_protected_len)
                Indices shared with the radix tree from a previous insert.
                **Do not free** — the tree already owns them.

            [cache_protected_len, new_prefix_len)
                Indices allocated by THIS request that turned out to overlap
                with tree nodes inserted concurrently. The tree already
                holds cloned copies → these are duplicates → **free them**.

            [new_prefix_len, total_len)
                Indices that ``insert()`` just added to the tree (cloned).
                The tree now owns the underlying KV pool slots.
                **Do not free** — the tree will free during eviction.

        When the radix cache is disabled, all KV indices are freed directly.
        """
        runner = self._runner
        cache = self._radix_cache

        slot = self._rid_to_req_pool_idx.pop(rid, None)
        kv_indices = self._rid_to_kv_indices.pop(rid, None)
        input_ids = self._rid_to_input_ids.pop(rid, None)
        output_ids = self._rid_to_output_ids.pop(rid, None)
        cache_protected_len = self._rid_to_cache_protected_len.pop(rid, 0)
        radix_lock = self._rid_to_radix_lock.pop(rid, None)
        self._rid_to_mrope_delta.pop(rid, None)

        # Free GDN track slot (if any) — the slot's association with a
        # radix node is managed separately via _node_id_to_gdn_track_slot
        # and the eviction callback; here we just remove the rid mapping.
        self._rid_to_gdn_track_slot.pop(rid, None)

        cache_enabled = cache is not None and not cache.disable

        # ----------------------------------------------------------
        # Phase 1: Read all KV indices BEFORE freeing anything.
        # ----------------------------------------------------------
        prompt_len = len(input_ids) if input_ids is not None else 0
        decode_len = len(output_ids) if output_ids else 0
        total_len = prompt_len + decode_len

        all_kv_indices: Optional[torch.Tensor] = None
        if slot is not None and input_ids is not None:
            all_kv_indices = runner.req_to_token_pool.req_to_token[slot, :total_len].to(
                torch.int64
            )

        # ----------------------------------------------------------
        # Phase 2: Insert into radix cache (if enabled).
        # ----------------------------------------------------------
        did_insert = False
        if cache_enabled and all_kv_indices is not None:
            if self._is_hybrid and decode_len > 0:
                # Hybrid model: insert only prompt tokens (not decode)
                # because GDN state is only tracked at the prompt boundary.
                prompt_kv = all_kv_indices[:prompt_len]
                decode_kv = all_kv_indices[prompt_len:]
                key = RadixKey(list(input_ids))
                result = cache.insert(key, prompt_kv)
                new_prefix_len = result.prefix_len

                # Free duplicate KV indices in the overlap region.
                if new_prefix_len > cache_protected_len:
                    dup_indices = prompt_kv[cache_protected_len:new_prefix_len]
                    if dup_indices.numel() > 0:
                        runner.token_to_kv_pool_allocator.free(dup_indices)

                # Free decode KV indices (tree does not own them)
                if decode_kv.numel() > 0:
                    runner.token_to_kv_pool_allocator.free(decode_kv)
            else:
                # Non-hybrid or no decode tokens: insert full sequence
                full_token_ids = list(input_ids)
                if output_ids:
                    full_token_ids.extend(output_ids)
                key = RadixKey(full_token_ids)
                result = cache.insert(key, all_kv_indices)
                new_prefix_len = result.prefix_len

                # Free duplicate KV indices in the overlap region.
                if new_prefix_len > cache_protected_len:
                    dup_indices = all_kv_indices[cache_protected_len:new_prefix_len]
                    if dup_indices.numel() > 0:
                        runner.token_to_kv_pool_allocator.free(dup_indices)

            did_insert = True

        # ----------------------------------------------------------
        # Phase 3: Unlock radix cache nodes.
        # ----------------------------------------------------------
        if cache_enabled and radix_lock is not None:
            node, swa_boundary_id = radix_lock
            cache.dec_lock_ref(node, swa_boundary_id)

        # ----------------------------------------------------------
        # Phase 4: Free KV indices not owned by the radix cache.
        # ----------------------------------------------------------
        if not did_insert:
            if cache_enabled and all_kv_indices is not None:
                # Cache enabled but insert skipped (shouldn't happen in
                # normal flow). Tree owns [0, cache_protected_len);
                # free the rest.
                tail = all_kv_indices[cache_protected_len:]
                if tail.numel() > 0:
                    runner.token_to_kv_pool_allocator.free(tail)
            elif not cache_enabled:
                # Cache disabled — free all newly-allocated KV indices.
                if all_kv_indices is not None and all_kv_indices.numel() > 0:
                    runner.token_to_kv_pool_allocator.free(all_kv_indices)
                elif kv_indices is not None and kv_indices.numel() > 0:
                    runner.token_to_kv_pool_allocator.free(kv_indices)

        # ----------------------------------------------------------
        # Phase 5: Free the req pool slot.
+ # ---------------------------------------------------------- + if slot is not None: + runner.req_to_token_pool.free(slot) + + logger.debug( + "Freed resources for rid=%s (slot=%s, kv_tokens=%d)", + rid, + slot, + kv_indices.numel() if kv_indices is not None else 0, + ) + + # ------------------------------------------------------------------ + # GDN state tracking helpers (hybrid models) + # ------------------------------------------------------------------ + + def _track_gdn_state_after_extend( + self, requests_meta: List[Dict[str, Any]] + ) -> None: + """Snapshot working GDN state into each request's track slot. + + Called immediately after ``runner.forward()`` for extend batches so + that the FINAL recurrent/conv state (after processing the full prompt) + is saved. The track slot is later associated with a radix node in + ``_insert_into_radix_cache``. + """ + gdn_pool = getattr(self._runner, "gdn_pool", None) + if gdn_pool is None: + return + + for m in requests_meta: + rid = m["rid"] + working_slot = self._rid_to_req_pool_idx.get(rid) + track_slot = self._rid_to_gdn_track_slot.get(rid) + if working_slot is not None and track_slot is not None: + gdn_pool.copy_states(working_slot, track_slot) + + def _on_radix_node_evict(self, node_id: int) -> None: + """Callback invoked by RadixCache when a node is evicted. + + Frees the GDN track slot associated with the evicted node. 
+ """ + track_slot = self._node_id_to_gdn_track_slot.pop(node_id, None) + if track_slot is not None: + gdn_pool = getattr(self._runner, "gdn_pool", None) + if gdn_pool is not None: + gdn_pool.free_track_slot(track_slot) + logger.debug( + "Freed GDN track slot %d for evicted node %d", + track_slot, node_id, + ) + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + def shutdown(self) -> None: + if self._runner is not None: + self._runner.shutdown() diff --git a/pymllm/orchestrator/parallel_state.py b/pymllm/orchestrator/parallel_state.py new file mode 100644 index 000000000..9fb208769 --- /dev/null +++ b/pymllm/orchestrator/parallel_state.py @@ -0,0 +1,183 @@ +"""Minimal parallel state for single-GPU serving. + +pymllm targets single-GPU, high-concurrency inference. This module keeps +the TP / DP / PP scaffolding so the rest of the codebase can query ranks +and groups uniformly, but the default (and expected) case is world_size=1. 
+""" + +import logging +from typing import Optional + +import torch +import torch.distributed as dist + +from pymllm.orchestrator.group_coordinator import GroupCoordinator + +logger = logging.getLogger(__name__) + +_TP_GROUP: Optional[GroupCoordinator] = None +_DP_GROUP: Optional[GroupCoordinator] = None +_PP_GROUP: Optional[GroupCoordinator] = None + +_TP_RANK: int = 0 +_TP_SIZE: int = 1 +_DP_RANK: int = 0 +_DP_SIZE: int = 1 +_PP_RANK: int = 0 +_PP_SIZE: int = 1 + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + data_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: str = "nccl", +) -> None: + global _TP_GROUP, _DP_GROUP, _PP_GROUP + global _TP_RANK, _TP_SIZE, _DP_RANK, _DP_SIZE, _PP_RANK, _PP_SIZE + + _TP_SIZE = tensor_model_parallel_size + _DP_SIZE = data_parallel_size + _PP_SIZE = pipeline_model_parallel_size + + if not dist.is_initialized(): + return + + world_size = dist.get_world_size() + world_rank = dist.get_rank() + local_rank = int(torch.cuda.current_device()) if torch.cuda.is_available() else 0 + + assert ( + tensor_model_parallel_size * data_parallel_size * pipeline_model_parallel_size + == world_size + ), ( + f"TP({tensor_model_parallel_size}) * DP({data_parallel_size}) * " + f"PP({pipeline_model_parallel_size}) != World({world_size})" + ) + + logger.info( + "Parallel init: world=%d rank=%d tp=%d dp=%d pp=%d", + world_size, + world_rank, + tensor_model_parallel_size, + data_parallel_size, + pipeline_model_parallel_size, + ) + + if tensor_model_parallel_size > 1: + num_tp_groups = world_size // tensor_model_parallel_size + for i in range(num_tp_groups): + ranks = list( + range( + i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size, + ) + ) + if world_rank in ranks: + _TP_GROUP = GroupCoordinator( + ranks=ranks, + local_rank=local_rank, + backend=backend, + ) + _TP_RANK = _TP_GROUP.rank_in_group + break + + if data_parallel_size > 1: + num_dp_groups = world_size // data_parallel_size 
+ for i in range(num_dp_groups): + ranks = list(range(i, world_size, num_dp_groups)) + if world_rank in ranks: + _DP_GROUP = GroupCoordinator( + ranks=ranks, + local_rank=local_rank, + backend=backend, + ) + _DP_RANK = _DP_GROUP.rank_in_group + break + + if pipeline_model_parallel_size > 1: + num_pp_groups = world_size // pipeline_model_parallel_size + for i in range(num_pp_groups): + start = i * pipeline_model_parallel_size + ranks = list(range(start, start + pipeline_model_parallel_size)) + if world_rank in ranks: + _PP_GROUP = GroupCoordinator( + ranks=ranks, + local_rank=local_rank, + backend=backend, + ) + _PP_RANK = _PP_GROUP.rank_in_group + break + + +# ---- group accessors ------------------------------------------------------ + + +def get_tp_group() -> Optional[GroupCoordinator]: + return _TP_GROUP + + +def get_dp_group() -> Optional[GroupCoordinator]: + return _DP_GROUP + + +def get_pp_group() -> Optional[GroupCoordinator]: + return _PP_GROUP + + +# ---- rank / size helpers -------------------------------------------------- + + +def get_tensor_model_parallel_rank() -> int: + return _TP_RANK + + +def get_tensor_model_parallel_world_size() -> int: + return _TP_SIZE + + +def get_data_parallel_rank() -> int: + return _DP_RANK + + +def get_data_parallel_world_size() -> int: + return _DP_SIZE + + +def get_pipeline_model_parallel_rank() -> int: + return _PP_RANK + + +def get_pipeline_model_parallel_world_size() -> int: + return _PP_SIZE + + +def model_parallel_is_initialized() -> bool: + return _TP_GROUP is not None or _DP_GROUP is not None or _PP_GROUP is not None + + +# ---- communication helpers ------------------------------------------------ + + +def tensor_model_parallel_all_reduce(tensor: torch.Tensor) -> torch.Tensor: + group = get_tp_group() + if group is None: + return tensor + return group.all_reduce(tensor) + + +def tensor_model_parallel_all_gather( + tensor: torch.Tensor, + dim: int = 0, +) -> torch.Tensor: + group = get_tp_group() + if group is 
None: + return tensor + return group.all_gather(tensor, dim=dim) + + +def data_parallel_all_reduce(tensor: torch.Tensor) -> torch.Tensor: + group = get_dp_group() + if group is None: + return tensor + return group.all_reduce(tensor) diff --git a/pymllm/orchestrator/request_response_process.py b/pymllm/orchestrator/request_response_process.py new file mode 100644 index 000000000..5c72a14c4 --- /dev/null +++ b/pymllm/orchestrator/request_response_process.py @@ -0,0 +1,189 @@ +""" +RequestResponseProcess -- the main-process entry point for user requests. + +This process is **not** a subprocess; it lives in the engine's main process. +Incoming requests are placed into an ``asyncio.Queue`` and forwarded to the +TokenizerProcess via ZMQ. Decoded results arrive back from the +DetokenizerProcess and are dispatched to the waiting callers. + +The request-tracking model uses ``ReqState`` pattern: each request +gets an ``asyncio.Event`` + output list so that streaming (multiple incremental +chunks) and one-shot responses are both supported. +""" + +import asyncio +import dataclasses +import logging +from typing import Any, Dict, List, Optional, Union + +import zmq +import zmq.asyncio + +from pymllm.engine.io_struct import GenerateReqInput +from pymllm.orchestrator.ipc_utils import create_zmq_socket, close_zmq_socket + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ReqState: + """Per-request state that supports both streaming and one-shot responses. + + ``ReqState`` (Event + out_list). + + The recv loop appends results to *out_list* and signals *event*; + callers ``await event.wait()`` in a loop, consuming results until + *finished* is ``True``. 
+ """ + + out_list: List[Dict[str, Any]] = dataclasses.field(default_factory=list) + finished: bool = False + event: asyncio.Event = dataclasses.field(default_factory=asyncio.Event) + + +class RequestResponseProcess: + """Sits in the main process; bridges user-facing API and subprocess pipeline.""" + + def __init__( + self, + send_to_tokenizer_addr: str, + recv_from_detokenizer_addr: str, + ): + self._send_to_tokenizer_addr: str = send_to_tokenizer_addr + self._recv_from_detokenizer_addr: str = recv_from_detokenizer_addr + + # asyncio queue that buffers incoming user requests + self._request_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue() + + # rid -> ReqState (replaces the old rid -> Future dict) + self._rid_to_state: Dict[str, ReqState] = {} + + # ZMQ (async context, sockets created lazily in the event loop) + self._zmq_ctx: Optional[zmq.asyncio.Context] = None + self._send_to_tokenizer: Optional[zmq.asyncio.Socket] = None + self._recv_from_detokenizer: Optional[zmq.asyncio.Socket] = None + + self._loop_task: Optional[asyncio.Task] = None + + def start(self) -> None: + """Bind ZMQ sockets. Background tasks are started lazily by + :meth:`listen` on the first :meth:`add_request` call, so they + always run on the correct event loop regardless of whether the + caller is uvicorn, ``loop.run_until_complete``, or anything else. + """ + self._zmq_ctx = zmq.asyncio.Context() + self._send_to_tokenizer = create_zmq_socket( + self._zmq_ctx, + zmq.PUSH, + self._send_to_tokenizer_addr, + bind=True, + ) + self._recv_from_detokenizer = create_zmq_socket( + self._zmq_ctx, + zmq.PULL, + self._recv_from_detokenizer_addr, + bind=True, + ) + + def listen(self) -> None: + """Start the send/recv background tasks on the **current** running + event loop. Idempotent — subsequent calls are no-ops while the + tasks are still alive. + + Called automatically by :meth:`add_request`, so callers never need + to invoke this directly. 
+ """ + if self._loop_task is not None and not self._loop_task.done(): + return + loop = asyncio.get_running_loop() + self._loop_task = loop.create_task(self._run()) + logger.debug("RequestResponseProcess: background tasks started") + + async def add_request( + self, request: GenerateReqInput + ) -> Union[ReqState, List[ReqState]]: + """Enqueue request(s) and return the corresponding :class:`ReqState`(s). + + * **Single request** (``request.is_single is True``): behaves exactly as + before – registers one ``ReqState`` and enqueues one message. + * **Batch request** (``request.is_single is False``): splits the batch + into *N* individual sub-requests, registers a ``ReqState`` per rid, and + enqueues each sub-request separately so the downstream pipeline sees + independent messages. Returns a ``List[ReqState]`` in the same order + as the input rids. + + Callers should ``await state.event.wait()`` in a loop, consuming + ``state.out_list`` entries until ``state.finished`` is ``True``. + """ + self.listen() + + if request.is_single: + rid = request.rid if isinstance(request.rid, str) else request.rid[0] + state = ReqState() + self._rid_to_state[rid] = state + await self._request_queue.put(request.to_request_dict()) + return state + + # Batch path: fan-out into individual sub-requests. 
+ states: List[ReqState] = [] + for i in range(request.batch_size): + sub = request[i] + rid = sub.rid if isinstance(sub.rid, str) else str(sub.rid) + state = ReqState() + self._rid_to_state[rid] = state + await self._request_queue.put(sub.to_request_dict()) + states.append(state) + return states + + def remove_state(self, rid: str) -> None: + """Remove the ``ReqState`` for *rid* (called by the caller once done).""" + self._rid_to_state.pop(rid, None) + + async def abort_request(self, rid: str) -> None: + """Cancel a pending request and notify downstream processes.""" + state = self._rid_to_state.pop(rid, None) + if state is not None and not state.finished: + state.finished = True + state.out_list.append({"rid": rid, "error": "aborted", "finished": True}) + state.event.set() + await self._send_to_tokenizer.send_pyobj({"rid": rid, "abort": True}) + + async def shutdown(self) -> None: + if self._loop_task is not None: + self._loop_task.cancel() + if self._send_to_tokenizer is not None: + close_zmq_socket(self._send_to_tokenizer) + if self._recv_from_detokenizer is not None: + close_zmq_socket(self._recv_from_detokenizer) + if self._zmq_ctx is not None: + self._zmq_ctx.term() + + # ------------------------------------------------------------------ + # Internal loops + # ------------------------------------------------------------------ + + async def _run(self) -> None: + """Main loop: forward requests to tokenizer, receive results from detokenizer.""" + send_task = asyncio.create_task(self._send_loop()) + recv_task = asyncio.create_task(self._recv_loop()) + await asyncio.gather(send_task, recv_task) + + async def _send_loop(self) -> None: + """Drain the asyncio queue and push requests to the TokenizerProcess.""" + while True: + request = await self._request_queue.get() + await self._send_to_tokenizer.send_pyobj(request) + + async def _recv_loop(self) -> None: + """Receive decoded results from DetokenizerProcess and dispatch to ReqStates.""" + while True: + result = 
await self._recv_from_detokenizer.recv_pyobj() + rid = result.get("rid") + state = self._rid_to_state.get(rid) + if state is None: + logger.warning("Received result for unknown rid=%s", rid) + continue + state.out_list.append(result) + if result.get("finished", False): + state.finished = True + state.event.set() diff --git a/pymllm/orchestrator/scheduler_process.py b/pymllm/orchestrator/scheduler_process.py new file mode 100644 index 000000000..8594a8997 --- /dev/null +++ b/pymllm/orchestrator/scheduler_process.py @@ -0,0 +1,1017 @@ +""" +SchedulerProcess -- the central scheduling and inference hub. + +Receives tokenized requests from the TokenizerProcess, organises them into +batches, runs model forward passes via the **in-process** model runner +(sglang-style), and streams finished token IDs to the DetokenizerProcess. + +Architecture: the scheduler owns the :class:`ModelRunnerProcess` directly +(same process, direct function calls). GPU resources (KV cache, req pool +slots) are freed immediately when requests finish — no cross-process +communication needed. + +Request ingestion supports two modes: + 1. ZMQ path: Receive TokenizedGenerateReqInput via ZMQ recv_pyobj + 2. 
Shared queue fast path: Read from shared memory + multiprocessing queue + +The main ``event_loop``:: + + while True: + recv_requests() + process_input_requests() + batch = get_next_batch_to_run() # also frees finished GPU resources + if batch: + result = run_batch(batch) # direct call to model runner + process_batch_result(batch, result) + stream_output() +""" + +import logging +import queue as stdlib_queue +import time +from collections import deque +from multiprocessing.connection import Connection +from typing import Any, Deque, Dict, List, Optional + +import zmq + +from pymllm.engine.forward_batch import ForwardMode +from pymllm.engine.io_struct import BatchTokenIDOutput, TokenizedGenerateReqInput +from pymllm.orchestrator.cuda_ipc_transport import ( + TensorTransportMode, + unwrap_mm_inputs_from_ipc, +) +from pymllm.orchestrator.ipc_utils import create_zmq_socket, setup_subprocess_logging +from pymllm.orchestrator.shared_memory_queue import SharedMemoryManager, TensorQueue + +logger = logging.getLogger(__name__) + +# Default scheduling limits +_DEFAULT_MAX_RUNNING_REQUESTS = 256 +_DEFAULT_MAX_PREFILL_TOKENS = 8192 +_DEFAULT_MAX_TOTAL_TOKENS = 131072 +_DEFAULT_MAX_NEW_TOKENS = 32768 + + +# ====================================================================== +# Req -- per-request state tracker +# ====================================================================== + + +class Req: + """Tracks a single request through its lifecycle (prefill -> decode -> finish). + + Created by :meth:`SchedulerProcess.process_input_requests` from a + :class:`~pymllm.engine.io_struct.TokenizedGenerateReqInput`. 
+ """ + + __slots__ = ( + "rid", + "input_ids", + "input_text", + "sampling_params", + "mm_inputs", + "stream", + "return_logprob", + "logprob_start_len", + "top_logprobs_num", + # KV-cache state + "req_pool_idx", + "seq_len", + # Prefix-cache hit (set during scheduling when radix cache is active) + "prefix_len", + # Generation state + "output_ids", + "finished_reason", + "is_prefilled", + # Sampling parameters (parsed) + "max_new_tokens", + "temperature", + "top_p", + "top_k", + "stop_token_ids", + # Streaming + "read_offset", + # Prompt length (for token accounting) + "prompt_len", + ) + + def __init__( + self, + rid: str, + input_ids: List[int], + input_text: str = "", + sampling_params: Optional[Dict[str, Any]] = None, + mm_inputs: Optional[Dict[str, Any]] = None, + stream: bool = False, + return_logprob: bool = False, + logprob_start_len: int = -1, + top_logprobs_num: int = 0, + ): + self.rid = rid + self.input_ids = list(input_ids) + self.input_text = input_text + self.mm_inputs = mm_inputs + self.stream = stream + self.return_logprob = return_logprob + self.logprob_start_len = logprob_start_len + self.top_logprobs_num = top_logprobs_num + + # Parse sampling params + sp = sampling_params or {} + self.sampling_params = sp + self.max_new_tokens: int = sp.get("max_new_tokens", _DEFAULT_MAX_NEW_TOKENS) + self.temperature: float = sp.get("temperature", 1.0) + self.top_p: float = sp.get("top_p", 1.0) + self.top_k: int = sp.get("top_k", -1) + self.stop_token_ids: List[int] = list(sp.get("stop_token_ids", [])) + + # KV-cache state (assigned during scheduling) + self.req_pool_idx: int = -1 + self.seq_len: int = len(input_ids) + # Number of prefix tokens served from the radix/KV cache (0 = no hit). + # Updated by process_batch_result when the model runner reports a + # prefix cache hit. Used in _free_req_resources to correctly + # release the token budget. 
+ self.prefix_len: int = 0 + + # Generation state + self.output_ids: List[int] = [] + self.finished_reason: Optional[str] = None + self.is_prefilled: bool = False + + # Streaming + self.read_offset: int = 0 + + # Prompt length + self.prompt_len: int = len(input_ids) + + def check_finished(self, eos_token_id: Optional[int] = None) -> bool: + """Check if this request has reached a finish condition. + + Sets ``finished_reason`` and returns True if finished. + Checks: + 1. EOS token in the latest generated token + 2. ``max_new_tokens`` reached + """ + if self.finished_reason is not None: + return True + + if self.output_ids: + last_token = self.output_ids[-1] + # Check model EOS token + if eos_token_id is not None and last_token == eos_token_id: + self.finished_reason = "eos" + return True + # Check stop token IDs from sampling params + if last_token in self.stop_token_ids: + self.finished_reason = "eos" + return True + + # Check max_new_tokens + if len(self.output_ids) >= self.max_new_tokens: + self.finished_reason = "length" + return True + + return False + + @property + def is_finished(self) -> bool: + return self.finished_reason is not None + + def abort(self) -> None: + """Mark this request as aborted.""" + self.finished_reason = "abort" + + def __repr__(self) -> str: + return ( + f"Req(rid={self.rid!r}, seq_len={self.seq_len}, " + f"out={len(self.output_ids)}, finished={self.finished_reason})" + ) + + +# ====================================================================== +# ScheduleBatch -- batch container +# ====================================================================== + + +class ScheduleBatch: + """Wraps a list of :class:`Req` objects for a single forward pass. + + Provides helpers to assemble the batch dict sent to the ModelRunnerProcess + in the format expected by :class:`~pymllm.engine.forward_batch.ForwardBatch`. 
+ """ + + def __init__(self, reqs: List[Req], forward_mode: ForwardMode): + self.reqs = reqs + self.forward_mode = forward_mode + + @property + def batch_size(self) -> int: + return len(self.reqs) + + def prepare_for_extend(self) -> Dict[str, Any]: + """Assemble a batch dict for prefill / extend forward pass. + + Returns a dict with flattened ``input_ids``, per-request ``positions``, + ``req_pool_indices``, ``seq_lens``, ``extend_seq_lens``, + ``extend_prefix_lens``, and request metadata. + + Note: The scheduler sends the **full** input_ids (no prefix trimming). + The ModelRunnerProcess performs radix cache prefix matching and + rebuilds the tensors with actual prefix lengths before the forward + pass. The ``extend_prefix_lens`` here are always 0 from the + scheduler; they serve as placeholders. + """ + all_input_ids: List[int] = [] + all_positions: List[int] = [] + req_pool_indices: List[int] = [] + seq_lens: List[int] = [] + extend_seq_lens: List[int] = [] + extend_prefix_lens: List[int] = [] + requests_meta: List[Dict[str, Any]] = [] + + for req in self.reqs: + input_len = len(req.input_ids) + + # Send full input_ids; model runner will trim based on prefix + all_input_ids.extend(req.input_ids) + all_positions.extend(range(input_len)) + req_pool_indices.append(req.req_pool_idx) + seq_lens.append(req.seq_len) + extend_seq_lens.append(input_len) + extend_prefix_lens.append(0) + requests_meta.append( + { + "rid": req.rid, + "input_ids": req.input_ids, + "mm_inputs": req.mm_inputs, + "sampling_params": req.sampling_params, + "return_logprob": req.return_logprob, + "logprob_start_len": req.logprob_start_len, + "top_logprobs_num": req.top_logprobs_num, + } + ) + + return { + "forward_mode": "extend", + "batch_size": self.batch_size, + "input_ids": all_input_ids, + "positions": all_positions, + "req_pool_indices": req_pool_indices, + "seq_lens": seq_lens, + "extend_seq_lens": extend_seq_lens, + "extend_prefix_lens": extend_prefix_lens, + "requests": requests_meta, + 
"batch_id": id(self), + "created_at": time.time(), + } + + def prepare_for_decode(self) -> Dict[str, Any]: + """Assemble a batch dict for decode forward pass (one token per request). + + Returns a dict with one input token per request (the last generated + token), positions at ``seq_len``, and request metadata. + """ + all_input_ids: List[int] = [] + all_positions: List[int] = [] + req_pool_indices: List[int] = [] + seq_lens: List[int] = [] + requests_meta: List[Dict[str, Any]] = [] + + for req in self.reqs: + # For decode, the input is the last generated token + if req.output_ids: + all_input_ids.append(req.output_ids[-1]) + else: + # Fallback: last input token (shouldn't happen normally) + all_input_ids.append(req.input_ids[-1]) + all_positions.append(req.seq_len) + req_pool_indices.append(req.req_pool_idx) + seq_lens.append(req.seq_len) + requests_meta.append( + { + "rid": req.rid, + "sampling_params": req.sampling_params, + "return_logprob": req.return_logprob, + "logprob_start_len": req.logprob_start_len, + "top_logprobs_num": req.top_logprobs_num, + } + ) + + return { + "forward_mode": "decode", + "batch_size": self.batch_size, + "input_ids": all_input_ids, + "positions": all_positions, + "req_pool_indices": req_pool_indices, + "seq_lens": seq_lens, + "requests": requests_meta, + "batch_id": id(self), + "created_at": time.time(), + } + + def to_batch_dict(self) -> Dict[str, Any]: + """Build the batch dict appropriate for the current forward mode.""" + if self.forward_mode.is_extend(): + return self.prepare_for_extend() + else: + return self.prepare_for_decode() + + def __repr__(self) -> str: + return f"ScheduleBatch(mode={self.forward_mode.name}, size={self.batch_size})" + + +# ====================================================================== +# SchedulerProcess +# ====================================================================== + + +class SchedulerProcess: + """Runs inside a subprocess. 
Central hub that drives the inference loop.""" + + def __init__( + self, + recv_from_tokenizer_addr: str, + send_to_detokenizer_addr: str, + server_config: Optional[Any] = None, + model_config: Optional[Any] = None, + gpu_id: int = 0, + shared_queue: Optional[TensorQueue] = None, + enable_shared_queue: bool = False, + tensor_transport_mode: TensorTransportMode = "default", + # Scheduling limits + max_running_requests: int = _DEFAULT_MAX_RUNNING_REQUESTS, + max_prefill_tokens: int = _DEFAULT_MAX_PREFILL_TOKENS, + max_total_tokens: int = _DEFAULT_MAX_TOTAL_TOKENS, + eos_token_ids: Optional[List[int]] = None, + default_max_new_tokens: int = _DEFAULT_MAX_NEW_TOKENS, + ): + # ZMQ addresses (tokenizer + detokenizer only) + self._recv_from_tokenizer_addr = recv_from_tokenizer_addr + self._send_to_detokenizer_addr = send_to_detokenizer_addr + + # Model config (for in-process model runner, sglang-style) + self._server_config = server_config + self._model_config = model_config + self._gpu_id = gpu_id + + # Shared queue configuration + self._shared_queue = shared_queue + self._enable_shared_queue = enable_shared_queue + self._tensor_transport_mode = tensor_transport_mode + + # ZMQ runtime objects (initialised in init_sockets) + self._zmq_ctx: Optional[zmq.Context] = None + self._recv_from_tokenizer: Optional[zmq.Socket] = None + self._send_to_detokenizer: Optional[zmq.Socket] = None + self._poller: Optional[zmq.Poller] = None + + # In-process model runner (initialised in init_model) + self._model_runner = None + + # Request management -- three-stage pipeline + self._waiting_queue: Deque[TokenizedGenerateReqInput] = deque() + self._pending_queue: List[Req] = [] + self._running_batch: List[Req] = [] + self._finished: List[Dict[str, Any]] = [] + + # Scheduling limits + self._max_running_requests = max_running_requests + self._max_prefill_tokens = max_prefill_tokens + + # KV-cache token budget (simplified single-GPU tracking). 
+ self._max_total_tokens = max_total_tokens + self._used_tokens: int = 0 + + # EOS token(s) for finish detection + self._eos_token_ids: List[int] = list(eos_token_ids) if eos_token_ids else [] + + # Default max_new_tokens (from model config or fallback) + self._default_max_new_tokens = default_max_new_tokens + + # Monotonic request-slot counter (simplified; no GPU pool access) + self._next_req_pool_idx: int = 0 + + # ------ Throughput metrics (sglang-style interval logging) ------ + # How often (in decode batches) to log throughput stats. + self._decode_log_interval: int = ( + server_config.decode_log_interval + if server_config is not None and hasattr(server_config, "decode_log_interval") + else 40 + ) + # Accumulators reset at each log interval + self._num_prefill_tokens: int = 0 # new prefill tokens (excluding cache hits) + self._num_prefill_cache_tokens: int = 0 # prefill tokens served from cache + self._num_decode_tokens: int = 0 # generated decode tokens + self._num_prefill_reqs: int = 0 # prefill requests count + # Timestamps for throughput calculation + self._last_prefill_stats_tic: float = time.time() + self._last_decode_stats_tic: float = time.time() + # Forward pass counters + self._forward_ct_decode: int = 0 + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def init_sockets(self) -> None: + self._zmq_ctx = zmq.Context() + + self._recv_from_tokenizer = create_zmq_socket( + self._zmq_ctx, + zmq.PULL, + self._recv_from_tokenizer_addr, + bind=False, + ) + self._send_to_detokenizer = create_zmq_socket( + self._zmq_ctx, + zmq.PUSH, + self._send_to_detokenizer_addr, + bind=True, + ) + + # Poller for non-blocking recv from tokenizer + self._poller = zmq.Poller() + self._poller.register(self._recv_from_tokenizer, zmq.POLLIN) + + def init_model(self) -> None: + """Create and initialise the in-process model runner (sglang-style). 
+ + Must be called after ``init_sockets`` and inside the subprocess + (after spawn) since it performs CUDA initialisation. + """ + from pymllm.orchestrator.model_runner_process import ModelRunnerProcess + + self._model_runner = ModelRunnerProcess( + gpu_id=self._gpu_id, + server_config=self._server_config, + model_config=self._model_config, + ) + self._model_runner.init_model() + logger.info("In-process model runner initialised on GPU %d", self._gpu_id) + + def event_loop(self) -> None: + """Infinite scheduling loop.""" + logger.info( + "SchedulerProcess event loop started (shared_queue=%s, transport=%s)", + self._enable_shared_queue, + self._tensor_transport_mode, + ) + while True: + self.recv_requests() + self.process_input_requests() + batch = self.get_next_batch_to_run() + if batch is not None: + result = self.run_batch(batch) + self.process_batch_result(batch, result) + self.stream_output() + + # ------------------------------------------------------------------ + # Step 1: receive tokenized requests (non-blocking) + # ------------------------------------------------------------------ + + def recv_requests(self) -> None: + """Non-blocking receive of tokenized requests from TokenizerProcess. + + Supports two modes: + 1. Legacy ZMQ: Uses ``zmq.Poller`` with a short timeout + 2. Shared queue: Non-blocking get from multiprocessing.Queue + + Messages are either: + * A :class:`~pymllm.engine.io_struct.TokenizedGenerateReqInput` + dataclass – appended to ``_waiting_queue``. + * A plain abort sentinel dict ``{"rid": ..., "abort": True}`` – handled + inline by removing the matching rid from the waiting queue. 
+ """ + if self._enable_shared_queue and self._shared_queue is not None: + self._recv_from_shared_queue() + else: + self._recv_from_zmq() + + def _recv_from_zmq(self) -> None: + """Receive requests via legacy ZMQ path.""" + while True: + events = dict(self._poller.poll(timeout=0)) # non-blocking + if self._recv_from_tokenizer not in events: + break + msg = self._recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) + # Abort sentinel: plain dict with "abort" key. + if isinstance(msg, dict) and msg.get("abort"): + rid = msg.get("rid") + logger.debug("Scheduler received abort for rid=%s", rid) + self._waiting_queue = type(self._waiting_queue)( + r for r in self._waiting_queue if r.rid != rid + ) + # Also abort from pending queue + self._abort_request(rid) + else: + self._waiting_queue.append(msg) + + def _recv_from_shared_queue(self) -> None: + """Receive requests via shared memory + shared queue fast path. + + After reading a ``(rid, shm_name, mm_inputs)`` tuple from the queue: + 1. The tokenized metadata is read from the POSIX shared memory segment. + 2. If CUDA IPC is enabled, ``mm_inputs`` may contain + :class:`~pymllm.orchestrator.cuda_ipc_transport.CudaIpcTensorTransportProxy` + or :class:`~pymllm.orchestrator.cuda_ipc_transport.TransportProxyTensor` + objects that are reconstructed by calling + :func:`~pymllm.orchestrator.cuda_ipc_transport.unwrap_mm_inputs_from_ipc`. + This step also increments sync flags so the sender can recycle pool chunks. + 3. A full ``TokenizedGenerateReqInput`` is assembled and appended to + ``_waiting_queue``. 
+ """ + while True: + try: + rid, shm_name, mm_inputs = self._shared_queue.get(timeout=0.0001) + + # Read metadata from shared memory (and unlink immediately) + metadata: TokenizedGenerateReqInput = SharedMemoryManager.read_metadata( + shm_name, unlink=True + ) + + # Reconstruct GPU tensors from CUDA IPC handles (if any) + if self._tensor_transport_mode in ("cuda_ipc", "cuda_ipc_pool"): + mm_inputs = unwrap_mm_inputs_from_ipc(mm_inputs) + + # Reassemble the full request + full_request = TokenizedGenerateReqInput( + rid=metadata.rid, + input_text=metadata.input_text, + input_ids=metadata.input_ids, + mm_inputs=mm_inputs, + sampling_params=metadata.sampling_params, + stream=metadata.stream, + return_logprob=metadata.return_logprob, + logprob_start_len=metadata.logprob_start_len, + top_logprobs_num=metadata.top_logprobs_num, + lora_path=metadata.lora_path, + session_params=metadata.session_params, + ) + + self._waiting_queue.append(full_request) + logger.debug("Received request %s from shared queue", rid) + + except stdlib_queue.Empty: + break + except Exception as exc: + logger.error( + "Error receiving from shared queue: %s", exc, exc_info=True + ) + try: + if "shm_name" in locals(): + SharedMemoryManager.cleanup(shm_name) + except Exception: + pass + break + + # ------------------------------------------------------------------ + # Step 2: process input requests + # ------------------------------------------------------------------ + + def process_input_requests(self) -> None: + """Convert raw :class:`TokenizedGenerateReqInput` in ``_waiting_queue`` + into :class:`Req` objects and move them to ``_pending_queue``. + + For each request: + 1. Parse sampling params (max_new_tokens, temperature, top_p, top_k, + stop_token_ids with defaults from EOS token). + 2. Create a ``Req`` object. + 3. Move from ``_waiting_queue`` to ``_pending_queue``. 
+ """ + while self._waiting_queue: + raw = self._waiting_queue.popleft() + + # Merge EOS token into stop_token_ids if not already present + sp = dict(raw.sampling_params) if raw.sampling_params else {} + # Inject model-aware default for max_new_tokens when not provided + if "max_new_tokens" not in sp: + sp["max_new_tokens"] = self._default_max_new_tokens + stop_ids = list(sp.get("stop_token_ids", [])) + for eid in self._eos_token_ids: + if eid not in stop_ids: + stop_ids.append(eid) + sp["stop_token_ids"] = stop_ids + + req = Req( + rid=raw.rid, + input_ids=raw.input_ids, + input_text=raw.input_text, + sampling_params=sp, + mm_inputs=raw.mm_inputs, + stream=raw.stream, + return_logprob=raw.return_logprob, + logprob_start_len=raw.logprob_start_len, + top_logprobs_num=raw.top_logprobs_num, + ) + self._pending_queue.append(req) + logger.debug("Processed input request %s (len=%d)", req.rid, req.seq_len) + + # ------------------------------------------------------------------ + # Step 3: build the next batch + # ------------------------------------------------------------------ + + def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: + """Implements continuous batching with two phases. + + 1. **Filter finished**: Remove finished requests from + ``_running_batch`` and free their token budget. + 2. **Schedule new prefills**: From ``_pending_queue``, admit + requests that fit within the token budget and + ``max_running_requests``. + 3. **Build batch**: + - If new prefill requests exist -> EXTEND batch + - Else if running decode requests exist -> DECODE batch + - Else -> None (idle) + + Note on prefix cache: The actual prefix matching is done by the + ModelRunnerProcess (which owns the RadixCache). The scheduler + uses ``input_len`` as a conservative budget estimate. The model + runner reports back actual ``prefix_len`` in results, and the + scheduler adjusts ``_used_tokens`` accordingly in + ``process_batch_result``. 
+ """ + # Phase 1: filter finished requests from running batch + still_running: List[Req] = [] + for req in self._running_batch: + if req.is_finished: + self._model_runner._free_rid_resources(req.rid) + self._free_req_resources(req) + else: + still_running.append(req) + self._running_batch = still_running + + # Phase 2: schedule new prefill requests from pending queue + new_prefill: List[Req] = [] + remaining_pending: List[Req] = [] + prefill_token_budget = self._max_prefill_tokens + + for req in self._pending_queue: + input_len = len(req.input_ids) + total_running = len(self._running_batch) + len(new_prefill) + + # Check capacity constraints. + # We reserve the full input_len as KV budget (conservative). + # If the model runner finds a prefix cache hit, some tokens + # won't need new KV allocation; the budget is corrected in + # process_batch_result. + can_fit_request = total_running < self._max_running_requests + can_fit_tokens = (self._used_tokens + input_len) <= self._max_total_tokens + can_fit_prefill = input_len <= prefill_token_budget + + if can_fit_request and can_fit_tokens and can_fit_prefill: + # Allocate req pool slot + req.req_pool_idx = self._next_req_pool_idx + self._next_req_pool_idx += 1 + # Reserve token budget (full input_len as conservative estimate) + self._used_tokens += input_len + prefill_token_budget -= input_len + new_prefill.append(req) + logger.debug( + "Scheduled prefill for %s (len=%d, used=%d/%d)", + req.rid, + input_len, + self._used_tokens, + self._max_total_tokens, + ) + else: + remaining_pending.append(req) + + self._pending_queue = remaining_pending + + # Phase 3: build batch + if new_prefill: + return ScheduleBatch(new_prefill, ForwardMode.EXTEND) + elif self._running_batch: + return ScheduleBatch(self._running_batch, ForwardMode.DECODE) + else: + return None + + # ------------------------------------------------------------------ + # Step 4: run the batch via ModelRunnerProcess + # 
------------------------------------------------------------------ + + def run_batch(self, batch: ScheduleBatch) -> Dict[str, Any]: + """Execute the batch via the in-process model runner (sglang-style). + + Direct function call — no ZMQ serialisation overhead. + """ + batch_dict = batch.to_batch_dict() + return self._model_runner._forward_batch(batch_dict) + + # ------------------------------------------------------------------ + # Step 5: process batch result + # ------------------------------------------------------------------ + + def process_batch_result( + self, batch: ScheduleBatch, result: Dict[str, Any] + ) -> None: + """Handle the result returned by the ModelRunnerProcess. + + For each request in the result: + 1. Update ``prefix_len`` from the model runner's radix cache hit. + 2. Adjust ``_used_tokens`` if a prefix cache hit was found (the + scheduler over-reserved during scheduling). + 3. Append new token(s) to ``req.output_ids``. + 4. Increment ``req.seq_len``. + 5. Call ``req.check_finished()`` (EOS token, max_new_tokens). + 6. If prefill request: mark ``req.is_prefilled = True``, move to + running batch for decode. + 7. If finished: collect for output, free KV-cache budget. + """ + # Build a rid -> Req lookup for the batch + rid_to_req: Dict[str, Req] = {req.rid: req for req in batch.reqs} + + # The result may contain per-request outputs in "finished" and + # "unfinished" lists, or a flat "outputs" list. Handle both. + output_items: List[Dict[str, Any]] = [] + output_items.extend(result.get("finished", [])) + output_items.extend(result.get("unfinished", [])) + if "outputs" in result: + output_items.extend(result["outputs"]) + + for out in output_items: + rid = out.get("rid") + req = rid_to_req.get(rid) + if req is None: + logger.warning("Result for unknown rid=%s, skipping", rid) + continue + + # Update prefix_len from model runner's radix cache matching. + # The model runner reports the actual prefix_len it found. 
+ # The scheduler originally reserved full input_len in + # get_next_batch_to_run; correct the over-reservation now. + if "prefix_len" in out and batch.forward_mode.is_extend(): + actual_prefix_len = out["prefix_len"] + if actual_prefix_len > req.prefix_len: + saved = actual_prefix_len - req.prefix_len + req.prefix_len = actual_prefix_len + # Give back the over-reserved tokens. The model runner + # reused cached KV for `saved` tokens, so those tokens + # do not consume new KV pool slots. + self._used_tokens = max(0, self._used_tokens - saved) + logger.info( + "Prefix cache hit for rid=%s: %d tokens reused, " + "budget adjusted by -%d (used=%d/%d)", + rid, + actual_prefix_len, + saved, + self._used_tokens, + self._max_total_tokens, + ) + + # Append generated token(s) + new_token_ids = out.get("output_token_ids", []) + if isinstance(new_token_ids, int): + new_token_ids = [new_token_ids] + req.output_ids.extend(new_token_ids) + req.seq_len += len(new_token_ids) + + # Update token budget for newly generated tokens + self._used_tokens += len(new_token_ids) + + # Check finish conditions + req.check_finished(eos_token_id=self._eos_token_ids[0] if self._eos_token_ids else None) + + # Process batch requests based on forward mode + if batch.forward_mode.is_extend(): + # Prefill batch: mark as prefilled and route + for req in batch.reqs: + req.is_prefilled = True + if req.is_finished: + self._collect_finished_output(req) + self._model_runner._free_rid_resources(req.rid) + self._free_req_resources(req) + else: + self._running_batch.append(req) + + # --- Accumulate prefill metrics --- + total_input = 0 + total_cached = 0 + for req in batch.reqs: + total_input += req.prompt_len + total_cached += req.prefix_len + self._num_prefill_tokens += total_input - total_cached + self._num_prefill_cache_tokens += total_cached + self._num_prefill_reqs += len(batch.reqs) + self._log_prefill_stats() + else: + # Decode batch: check finish and collect + new_running: List[Req] = [] + for req in 
batch.reqs: + if req.is_finished: + self._collect_finished_output(req) + self._model_runner._free_rid_resources(req.rid) + self._free_req_resources(req) + else: + new_running.append(req) + self._running_batch = new_running + + # --- Accumulate decode metrics --- + self._num_decode_tokens += batch.batch_size # 1 token per request + self._forward_ct_decode += 1 + if ( + self._decode_log_interval > 0 + and self._forward_ct_decode % self._decode_log_interval == 0 + ): + self._log_decode_stats() + + # ------------------------------------------------------------------ + # Step 6: stream output to DetokenizerProcess + # ------------------------------------------------------------------ + + def stream_output(self) -> None: + """Send finished/streaming outputs to the DetokenizerProcess. + + Produces :class:`~pymllm.engine.io_struct.BatchTokenIDOutput`-compatible + dicts. For streaming requests, intermediate tokens are also sent. + """ + # Collect streaming outputs from running requests + for req in self._running_batch: + if req.stream and len(req.output_ids) > req.read_offset: + decode_ids = req.output_ids[req.read_offset :] + output = { + "rids": [req.rid], + "finished_reasons": [None], + "decode_ids": decode_ids, + "read_offsets": [req.read_offset], + "output_ids": list(req.output_ids), + "skip_special_tokens": [True], + "prompt_tokens": [req.prompt_len], + "completion_tokens": [len(req.output_ids)], + } + req.read_offset = len(req.output_ids) + self._send_to_detokenizer.send_pyobj(output) + + # Send finished outputs + while self._finished: + item = self._finished.pop(0) + self._send_to_detokenizer.send_pyobj(item) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _log_prefill_stats(self) -> None: + """Log prefill throughput at INFO level (called after each prefill batch).""" + now = time.time() + elapsed = now - self._last_prefill_stats_tic + 
self._last_prefill_stats_tic = now + + if elapsed > 0: + input_throughput = self._num_prefill_tokens / elapsed + else: + input_throughput = 0.0 + + logger.info( + "Prefill batch: %d reqs, " + "new tokens: %d, " + "cached tokens: %d, " + "input throughput: %.2f token/s", + self._num_prefill_reqs, + self._num_prefill_tokens, + self._num_prefill_cache_tokens, + input_throughput, + ) + # Reset accumulators + self._num_prefill_tokens = 0 + self._num_prefill_cache_tokens = 0 + self._num_prefill_reqs = 0 + + def _log_decode_stats(self) -> None: + """Log decode throughput at INFO level (called every decode_log_interval batches).""" + now = time.time() + elapsed = now - self._last_decode_stats_tic + self._last_decode_stats_tic = now + + if elapsed > 0: + gen_throughput = self._num_decode_tokens / elapsed + else: + gen_throughput = 0.0 + + logger.info( + "Decode: %d steps, " + "gen tokens: %d, " + "running: %d reqs, " + "gen throughput: %.2f token/s", + self._forward_ct_decode, + self._num_decode_tokens, + len(self._running_batch), + gen_throughput, + ) + # Reset accumulators + self._num_decode_tokens = 0 + self._forward_ct_decode = 0 + + def _collect_finished_output(self, req: Req) -> None: + """Build a finished output dict and add it to ``_finished``.""" + decode_ids = req.output_ids[req.read_offset :] + output: Dict[str, Any] = { + "rids": [req.rid], + "finished_reasons": [req.finished_reason], + "decode_ids": decode_ids, + "read_offsets": [req.read_offset], + "output_ids": list(req.output_ids), + "skip_special_tokens": [True], + "prompt_tokens": [req.prompt_len], + "completion_tokens": [len(req.output_ids)], + } + self._finished.append(output) + logger.debug( + "Request %s finished: reason=%s, tokens=%d", + req.rid, + req.finished_reason, + len(req.output_ids), + ) + + def _free_req_resources(self, req: Req) -> None: + """Release KV-cache token budget for a finished request. 
+ + The budget was charged as follows: + - At scheduling: ``+input_len`` (full prompt as conservative estimate) + - After prefix correction: ``-prefix_len`` (cached prefix doesn't need + new KV allocation; model runner manages those via radix cache) + - At each decode step: ``+1`` per generated token + + So the net charge for this request is: + ``(input_len - prefix_len) + num_decode_tokens`` + = ``seq_len - prefix_len`` + + We release exactly that amount. + """ + tokens_to_free = req.seq_len - req.prefix_len + self._used_tokens = max(0, self._used_tokens - tokens_to_free) + req.req_pool_idx = -1 + + def _abort_request(self, rid: str) -> None: + """Abort a request by rid from pending or running queues.""" + # Remove from pending queue + self._pending_queue = [r for r in self._pending_queue if r.rid != rid] + # Abort in running batch + for req in self._running_batch: + if req.rid == rid: + req.abort() + break + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + def shutdown(self) -> None: + if self._model_runner is not None: + self._model_runner.shutdown() + for sock in ( + self._recv_from_tokenizer, + self._send_to_detokenizer, + ): + if sock is not None: + sock.close() + if self._zmq_ctx is not None: + self._zmq_ctx.term() + + +def run_scheduler_process( + recv_from_tokenizer_addr: str, + send_to_detokenizer_addr: str, + pipe_writer: Connection, + shared_queue: Optional[TensorQueue] = None, + enable_shared_queue: bool = False, + tensor_transport_mode: TensorTransportMode = "default", + log_level: str = "info", + default_max_new_tokens: int = _DEFAULT_MAX_NEW_TOKENS, + eos_token_ids: Optional[List[int]] = None, + server_config: Optional[Any] = None, + model_config: Optional[Any] = None, + gpu_id: int = 0, +) -> None: + """Entry point for ``torch.multiprocessing.Process(target=...)``. 
+ + The scheduler process now also owns the model runner (sglang-style), + so model initialisation happens here. + """ + setup_subprocess_logging(log_level) + proc = SchedulerProcess( + recv_from_tokenizer_addr, + send_to_detokenizer_addr, + server_config=server_config, + model_config=model_config, + gpu_id=gpu_id, + shared_queue=shared_queue, + enable_shared_queue=enable_shared_queue, + tensor_transport_mode=tensor_transport_mode, + default_max_new_tokens=default_max_new_tokens, + eos_token_ids=eos_token_ids, + ) + proc.init_sockets() + proc.init_model() + + pipe_writer.send({"status": "ready", "process": "scheduler"}) + pipe_writer.close() + + try: + proc.event_loop() + except KeyboardInterrupt: + pass + finally: + proc.shutdown() diff --git a/pymllm/orchestrator/shared_memory_queue.py b/pymllm/orchestrator/shared_memory_queue.py new file mode 100644 index 000000000..2f006bdc0 --- /dev/null +++ b/pymllm/orchestrator/shared_memory_queue.py @@ -0,0 +1,292 @@ +""" +Shared memory and queue utilities for fast IPC between tokenizer and scheduler. + +This module implements the shared-queue fast path to avoid expensive ZMQ +serialization of large multimodal tensors. + +## Design + +- **Metadata lane**: Small tokenized objects are written to a POSIX shared memory + segment keyed by the request ID (``rid``). The scheduler reads and immediately + unlinks the segment. + +- **Tensor lane**: Large tensors can be transported in one of three modes, + controlled by ``TensorTransportMode`` (passed at queue construction time): + + * ``"default"`` – CPU tensors only. GPU tensors are moved to POSIX shared + memory via ``tensor.share_memory_()`` (or left on CPU if already there). + This is the original behaviour and requires no CUDA support. + + * ``"cuda_ipc"`` – GPU tensors stay on GPU and are wrapped in + :class:`~pymllm.orchestrator.cuda_ipc_transport.TransportProxyTensor`. 
On the + receiver side the proxy's ``__setstate__`` automatically reconstructs the + tensor from the CUDA IPC handle during unpickling. CPU tensors are handled as + in ``"default"`` mode. **Caveat**: GPU memory is not freed until the sender + process exits (PyTorch limitation). Prefer ``"cuda_ipc_pool"`` for services. + + * ``"cuda_ipc_pool"`` – GPU tensors are copied into a pre-allocated + :class:`~pymllm.orchestrator.cuda_ipc_transport.MmItemMemoryPool` workspace and + wrapped in :class:`~pymllm.orchestrator.cuda_ipc_transport.CudaIpcTensorTransportProxy`. + After the receiver copies the data it increments a sync flag and the sender's + recycler thread returns the chunk to the pool. This avoids GPU memory leaks. + CPU tensors are handled as in ``"default"`` mode. + +## Key relationship with CUDA IPC + +``"default"`` and ``"cuda_ipc*"`` modes are **mutually exclusive for GPU tensors**: + +- In ``"default"`` mode, GPU tensors that need to cross process boundaries must + first be moved to CPU (``share_memory_()``). This incurs a GPU→CPU copy. +- In ``"cuda_ipc*"`` modes, GPU tensors are shared as-is via CUDA IPC handles; + no copy to CPU is needed. + +CPU tensors are always handled via ``share_memory_()`` regardless of the mode. +""" + +from __future__ import annotations + +import logging +import pickle +import uuid +from multiprocessing import Queue +from multiprocessing.shared_memory import SharedMemory +from typing import Any, Dict, Literal, Optional + +import torch + +from pymllm.orchestrator.cuda_ipc_transport import ( + MmItemMemoryPool, + TensorTransportMode, + unwrap_mm_inputs_from_ipc, + wrap_mm_inputs_for_ipc, +) + +logger = logging.getLogger(__name__) + + +class SharedMemoryManager: + """Manages shared memory segments for passing metadata between processes. + + Each tokenized request's metadata is written to a unique shared memory + segment keyed by its request ID (rid). The scheduler reads and immediately + unlinks the segment to prevent memory leaks. 
+ """ + + @staticmethod + def write_metadata(rid: str, metadata: Any) -> str: + """Write metadata to shared memory and return the segment name. + + Args: + rid: Request ID (used as part of the shared memory name) + metadata: Serializable metadata object + + Returns: + str: The shared memory segment name + """ + data = pickle.dumps(metadata) + size = len(data) + shm_name = f"pymllm_meta_{rid}_{uuid.uuid4().hex[:8]}" + try: + shm = SharedMemory(name=shm_name, create=True, size=size) + shm.buf[:size] = data + shm.close() + logger.debug("Wrote %d bytes to shared memory %s", size, shm_name) + return shm_name + except Exception as exc: + logger.error("Failed to write metadata to shared memory: %s", exc) + raise + + @staticmethod + def read_metadata(shm_name: str, unlink: bool = True) -> Any: + """Read metadata from shared memory and optionally unlink it. + + Args: + shm_name: The shared memory segment name + unlink: If True, immediately unlink the segment after reading + + Returns: + The deserialized metadata object + """ + try: + shm = SharedMemory(name=shm_name, create=False) + data = bytes(shm.buf[:]) + metadata = pickle.loads(data) + shm.close() + if unlink: + try: + shm.unlink() + logger.debug("Read and unlinked shared memory %s", shm_name) + except FileNotFoundError: + pass + return metadata + except Exception as exc: + logger.error( + "Failed to read metadata from shared memory %s: %s", shm_name, exc + ) + raise + + @staticmethod + def cleanup(shm_name: str) -> None: + """Manually cleanup a shared memory segment (for error recovery).""" + try: + shm = SharedMemory(name=shm_name, create=False) + shm.close() + shm.unlink() + logger.debug("Cleaned up shared memory %s", shm_name) + except FileNotFoundError: + pass + except Exception as exc: + logger.warning("Failed to cleanup shared memory %s: %s", shm_name, exc) + + +class TensorQueue: + """Queue for passing large tensors between processes. 
+ + Depending on ``transport_mode``, GPU tensors are either moved to CPU shared + memory (``"default"``) or kept on GPU and shared via CUDA IPC handles + (``"cuda_ipc"`` / ``"cuda_ipc_pool"``). + + Args: + maxsize: Maximum queue size (0 for unlimited). + transport_mode: Controls how GPU tensors are transported. + pool: Required when ``transport_mode == "cuda_ipc_pool"``. + """ + + def __init__( + self, + maxsize: int = 0, + transport_mode: TensorTransportMode = "default", + pool: Optional[MmItemMemoryPool] = None, + ) -> None: + # pool is allowed to be None at construction time for "cuda_ipc_pool" mode + # because the pool is initialised lazily inside the sender subprocess. + # The pool reference is injected later via _pool attribute assignment. + self._queue: Queue = Queue(maxsize=maxsize) + self._transport_mode = transport_mode + self._pool = pool + + # ------------------------------------------------------------------ + # Producer side + # ------------------------------------------------------------------ + + def put( + self, + rid: str, + shm_name: str, + mm_inputs: Optional[Dict[str, Any]], + ) -> None: + """Put a request into the queue. + + GPU tensors inside *mm_inputs* are wrapped according to + ``transport_mode`` before being placed into the underlying + ``multiprocessing.Queue``. + + Args: + rid: Request ID. + shm_name: Shared memory segment name for the tokenized metadata. + mm_inputs: Multimodal inputs dict (may contain CUDA tensors). + """ + if mm_inputs is not None: + if self._transport_mode in ("cuda_ipc", "cuda_ipc_pool"): + if self._transport_mode == "cuda_ipc_pool" and self._pool is None: + # Pool not yet initialised (race condition or CUDA unavailable); + # fall back to simple CUDA IPC for this message. 
+ effective_mode = "cuda_ipc" + else: + effective_mode = self._transport_mode + # Wrap CUDA tensors in IPC proxies (stays on GPU, no copy to CPU) + mm_inputs = wrap_mm_inputs_for_ipc( + mm_inputs, + transport_mode=effective_mode, + pool=self._pool, + ) + # CPU tensors within mm_inputs are still shared via share_memory_() + mm_inputs = self._share_cpu_tensors(mm_inputs) + else: + # "default": move all tensors to CPU shared memory + mm_inputs = self._make_tensors_shareable(mm_inputs) + + self._queue.put((rid, shm_name, mm_inputs)) + logger.debug("Put request %s into tensor queue (shm=%s)", rid, shm_name) + + # ------------------------------------------------------------------ + # Consumer side + # ------------------------------------------------------------------ + + def get( + self, timeout: Optional[float] = None + ) -> tuple[str, str, Optional[Dict[str, Any]]]: + """Get a request from the queue. + + GPU tensors wrapped as IPC proxies are **not** automatically + reconstructed here – the caller (scheduler) must call + :func:`~pymllm.orchestrator.cuda_ipc_transport.unwrap_mm_inputs_from_ipc` + after retrieval. + + Args: + timeout: Timeout in seconds (None for blocking). + + Returns: + Tuple of ``(rid, shm_name, mm_inputs)``. 
+ """ + rid, shm_name, mm_inputs = self._queue.get(timeout=timeout) + logger.debug("Got request %s from tensor queue (shm=%s)", rid, shm_name) + return rid, shm_name, mm_inputs + + # ------------------------------------------------------------------ + # Queue introspection + # ------------------------------------------------------------------ + + def empty(self) -> bool: + return self._queue.empty() + + def qsize(self) -> int: + try: + return self._queue.qsize() + except NotImplementedError: + return 0 + + def close(self) -> None: + self._queue.close() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _make_tensors_shareable(data: Any) -> Any: + """Recursively move all tensors (CPU and CUDA) to POSIX shared memory. + + GPU tensors are first moved to CPU (incurring a device copy), then + placed in shared memory. This is the ``"default"`` path. + """ + if isinstance(data, torch.Tensor): + if data.is_cuda: + data = data.cpu() + if not data.is_shared(): + data = data.share_memory_() + return data + elif isinstance(data, dict): + return {k: TensorQueue._make_tensors_shareable(v) for k, v in data.items()} + elif isinstance(data, (list, tuple)): + result = [TensorQueue._make_tensors_shareable(item) for item in data] + return type(data)(result) + else: + return data + + @staticmethod + def _share_cpu_tensors(data: Any) -> Any: + """Recursively place CPU tensors in shared memory (GPU tensors are already + wrapped as IPC proxies and must not be touched here). 
+ """ + if isinstance(data, torch.Tensor) and not data.is_cuda: + if not data.is_shared(): + data = data.share_memory_() + return data + elif isinstance(data, dict): + return {k: TensorQueue._share_cpu_tensors(v) for k, v in data.items()} + elif isinstance(data, (list, tuple)): + result = [TensorQueue._share_cpu_tensors(item) for item in data] + return type(data)(result) + else: + return data diff --git a/pymllm/orchestrator/tokenizer_process.py b/pymllm/orchestrator/tokenizer_process.py new file mode 100644 index 000000000..703618a40 --- /dev/null +++ b/pymllm/orchestrator/tokenizer_process.py @@ -0,0 +1,504 @@ +""" +TokenizerProcess -- subprocess that tokenizes incoming raw requests. + +Receives raw requests from RequestResponseProcess via ZMQ, tokenizes them, +and forwards the tokenized payloads to the SchedulerProcess. + +Supports two transport modes (controlled by ``enable_shared_queue`` and +``tensor_transport_mode`` in the tokenizer config): + +1. **Legacy ZMQ path** (``enable_shared_queue=False``): + Tokenized objects are sent directly via ``ZMQ send_pyobj`` (pickle). This + is simple but slow for large multimodal tensors. + +2. **Shared queue fast path** (``enable_shared_queue=True``): + Metadata is written to POSIX shared memory and the queue carries a + lightweight ``(rid, shm_name, mm_inputs)`` tuple. The GPU tensors inside + ``mm_inputs`` are transported differently depending on ``tensor_transport_mode``: + + * ``"default"`` – GPU tensors are moved to CPU first (GPU→CPU copy), + then placed in POSIX shared memory. + * ``"cuda_ipc"`` – GPU tensors stay on GPU; they are wrapped in a + :class:`TransportProxyTensor` whose pickle uses CUDA IPC handles. + Simple but may leak GPU memory. + * ``"cuda_ipc_pool"`` – GPU tensors are copied into a pre-allocated + :class:`MmItemMemoryPool` workspace and shared via pool-chunk IPC + handles. Chunks are recycled; no GPU memory is leaked. 
+""" + +import logging +from multiprocessing.connection import Connection +from typing import Any, Dict, List, Optional, Union + +import zmq +from transformers import AutoProcessor, AutoTokenizer + +from pymllm.engine.io_struct import TokenizedGenerateReqInput +from pymllm.orchestrator.cuda_ipc_transport import MmItemMemoryPool, TensorTransportMode +from pymllm.orchestrator.ipc_utils import create_zmq_socket, setup_subprocess_logging +from pymllm.orchestrator.shared_memory_queue import SharedMemoryManager, TensorQueue + +logger = logging.getLogger(__name__) + + +class TokenizerProcess: + """Runs inside a subprocess spawned by ``torch.multiprocessing``.""" + + def __init__( + self, + recv_from_rr_addr: str, + send_to_scheduler_addr: str, + tokenizer_cfg: Dict[str, Any], + shared_queue: Optional[TensorQueue] = None, + ): + """ + Parameters + ---------- + tokenizer_cfg: + Serialisable dict built by the parent process (``Engine``) before + spawning. Required keys: + + * ``tokenizer_path`` – str, path to the tokenizer directory. + * ``tokenizer_mode`` – ``"auto" | "slow" | "fast"``. + * ``trust_remote_code`` – bool. + * ``context_length`` – Optional[int], explicit cap; inferred + from ``hf_config`` when ``None``. + * ``hf_config`` – Optional HuggingFace PretrainedConfig. + * ``enable_shared_queue`` – bool, whether to use shared memory fast path. + * ``tensor_transport_mode`` – ``"default" | "cuda_ipc" | "cuda_ipc_pool"``. + * ``cuda_ipc_pool_size_mb`` – int, pool size in MB (cuda_ipc_pool only). + * ``cuda_ipc_recycle_interval`` – float, recycler sleep interval (s). + + shared_queue: + Optional :class:`TensorQueue` for the shared memory fast path. + When *transport_mode* is ``"cuda_ipc_pool"`` this queue should have + been constructed with a ``MmItemMemoryPool``; the ``TokenizerProcess`` + initialises its own pool in that case. 
+ """ + self._recv_from_rr_addr = recv_from_rr_addr + self._send_to_scheduler_addr = send_to_scheduler_addr + self._tokenizer_cfg = tokenizer_cfg + self._enable_shared_queue = tokenizer_cfg.get("enable_shared_queue", False) + self._shared_queue = shared_queue + + # Tensor transport configuration + self._transport_mode: TensorTransportMode = tokenizer_cfg.get( + "tensor_transport_mode", "default" + ) + # Pool for cuda_ipc_pool mode – will be initialised lazily when the + # process first encounters a CUDA tensor. + self._ipc_pool: Optional[MmItemMemoryPool] = None + if self._transport_mode == "cuda_ipc_pool": + # The pool must be created inside the subprocess (after fork/spawn) + # because it allocates CUDA memory. We defer to _ensure_pool(). + pool_mb: int = int(tokenizer_cfg.get("cuda_ipc_pool_size_mb", 512)) + recycle: float = float(tokenizer_cfg.get("cuda_ipc_recycle_interval", 0.1)) + self._ipc_pool_size_mb = pool_mb + self._ipc_recycle_interval = recycle + + self._zmq_ctx: Optional[zmq.Context] = None + self._recv_from_rr: Optional[zmq.Socket] = None + self._send_to_scheduler: Optional[zmq.Socket] = None + + self._tokenizer = None + self._mm_processor = None + self._context_length: Optional[int] = None + + self._init_tokenizers() + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def init_sockets(self) -> None: + self._zmq_ctx = zmq.Context() + self._recv_from_rr = create_zmq_socket( + self._zmq_ctx, + zmq.PULL, + self._recv_from_rr_addr, + bind=False, + ) + self._send_to_scheduler = create_zmq_socket( + self._zmq_ctx, + zmq.PUSH, + self._send_to_scheduler_addr, + bind=True, + ) + + def event_loop(self) -> None: + """Infinite loop: recv raw request -> tokenize -> send to scheduler.""" + logger.info( + "TokenizerProcess event loop started (shared_queue=%s, transport=%s)", + self._enable_shared_queue, + self._transport_mode, + ) + while True: + 
    def _send_via_shared_queue(
        self, tokenized: Union[TokenizedGenerateReqInput, Dict[str, Any]]
    ) -> None:
        """Send tokenized request via shared memory + shared queue fast path.

        GPU tensors inside ``mm_inputs`` are handled according to
        ``self._transport_mode``:

        * ``"default"`` – moved to CPU via ``share_memory_()`` by ``TensorQueue``.
        * ``"cuda_ipc"`` – wrapped in :class:`TransportProxyTensor` (stays on GPU).
        * ``"cuda_ipc_pool"`` – copied into the :class:`MmItemMemoryPool` workspace and
          wrapped in :class:`CudaIpcTensorTransportProxy`.

        Abort sentinel messages are forwarded via ZMQ (they are lightweight dicts).
        """
        # Handle abort sentinel
        if isinstance(tokenized, dict) and tokenized.get("abort"):
            # Fallback to ZMQ for abort messages (no tensor payload)
            self._send_to_scheduler.send_pyobj(tokenized)
            return

        # NOTE(review): `assert` is stripped under `python -O`; consider an
        # explicit TypeError if this invariant must hold in production.
        assert isinstance(tokenized, TokenizedGenerateReqInput), (
            f"Expected TokenizedGenerateReqInput, got {type(tokenized)}"
        )

        # Lazily initialise the CUDA IPC pool (must happen inside the subprocess)
        if self._transport_mode == "cuda_ipc_pool":
            self._ensure_pool()

        rid = tokenized.rid
        mm_inputs = tokenized.mm_inputs

        # Create lightweight metadata object (mm_inputs sent separately via queue)
        # so the shared-memory blob stays small and tensor-free.
        metadata = TokenizedGenerateReqInput(
            rid=tokenized.rid,
            input_text=tokenized.input_text,
            input_ids=tokenized.input_ids,
            mm_inputs=None,  # Will be passed separately via shared queue
            sampling_params=tokenized.sampling_params,
            stream=tokenized.stream,
            return_logprob=tokenized.return_logprob,
            logprob_start_len=tokenized.logprob_start_len,
            top_logprobs_num=tokenized.top_logprobs_num,
            lora_path=tokenized.lora_path,
            session_params=tokenized.session_params,
        )

        # Write metadata to shared memory
        shm_name = SharedMemoryManager.write_metadata(rid, metadata)

        # Put (rid, shm_name, mm_inputs) into shared queue
        # TensorQueue.put() handles wrapping mm_inputs based on transport_mode
        self._shared_queue.put(rid, shm_name, mm_inputs)

        logger.debug(
            "Sent request %s via shared queue (shm=%s, transport=%s)",
            rid,
            shm_name,
            self._transport_mode,
        )
+ if self._shared_queue is not None: + self._shared_queue._transport_mode = self._transport_mode + self._shared_queue._pool = self._ipc_pool + + logger.info( + "MmItemMemoryPool initialised: %d MB on cuda:%d", + self._ipc_pool_size_mb, + device, + ) + except Exception as exc: + logger.error( + "Failed to initialise MmItemMemoryPool: %s; " + "falling back to transport_mode='default'", + exc, + exc_info=True, + ) + self._transport_mode = "default" + if self._shared_queue is not None: + self._shared_queue._transport_mode = "default" + + # ------------------------------------------------------------------ + # Tokenization and multimodal preprocessing + # ------------------------------------------------------------------ + + def _init_tokenizers(self) -> None: + """Initialise text tokenizer and (optionally) multimodal processor. + + All configuration is read from ``self._tokenizer_cfg`` which was + serialised by the parent process before ``spawn``. No global config + access happens inside the subprocess. + """ + cfg = self._tokenizer_cfg + tokenizer_path: str = cfg["tokenizer_path"] + tokenizer_mode: str = cfg.get("tokenizer_mode", "auto") + trust_remote_code: bool = bool(cfg.get("trust_remote_code", False)) + + tokenizer_kwargs: Dict[str, Any] = { + "use_fast": tokenizer_mode != "slow", + "trust_remote_code": trust_remote_code, + } + + self._tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, + **tokenizer_kwargs, + ) + + # Default to left padding for generation. + try: + self._tokenizer.padding_side = "left" + except Exception: + pass + + # Context length: explicit config value takes priority; fall back to + # common HF config field names. 
+ context_len: Optional[int] = cfg.get("context_length") + if context_len is None: + hf_cfg = cfg.get("hf_config") + for name in ("max_position_embeddings", "max_sequence_length", "seq_len"): + if hf_cfg is not None and hasattr(hf_cfg, name): + context_len = int(getattr(hf_cfg, name)) + break + self._context_length = context_len + + # Try to load multimodal processor (optional). + try: + self._mm_processor = AutoProcessor.from_pretrained( + tokenizer_path, + trust_remote_code=trust_remote_code, + ) + except Exception: + # Text-only models don't provide a processor; that's fine. + self._mm_processor = None + + def _tokenize( + self, raw_request: Dict[str, Any] + ) -> Union[TokenizedGenerateReqInput, Dict[str, Any]]: + """Tokenize one raw request dict and return a typed object. + + * **Abort** messages (``{"rid": ..., "abort": True}``) are returned as + plain dicts so the scheduler can intercept them without importing the + io_struct. + * Normal requests are returned as a :class:`TokenizedGenerateReqInput` + dataclass instance that carries ``input_ids``, ``mm_inputs``, and all + sampling meta-data in typed fields. + + Each message arriving here corresponds to exactly one sub-request + because batch splitting happens upstream in ``RequestResponseProcess``. + """ + # Abort: propagate as a plain sentinel dict. + if raw_request.get("abort"): + return {"rid": raw_request.get("rid"), "abort": True} + + # ------------------------------------------------------------------ # + # 1. Text tokenization + # ------------------------------------------------------------------ # + if raw_request.get("input_ids") is not None: + # Caller already tokenized – skip text processing. 
+ input_ids: List[int] = list(raw_request["input_ids"]) + raw_text = raw_request.get("text") + input_text: str = ( + str(raw_text[0]) if isinstance(raw_text, list) else str(raw_text or "") + ) + else: + text = raw_request.get("text") + if text is None: + raise ValueError( + "TokenizerProcess expects either `text` or `input_ids`." + ) + # Accept a list for robustness; take the first element. + input_text = str(text[0]) if isinstance(text, list) else str(text) + logger.debug(f"Tokenizing input text {input_text}") + + encode_kwargs: Dict[str, Any] = { + "add_special_tokens": True, + "return_attention_mask": False, + } + if self._context_length is not None: + encode_kwargs.update( + {"truncation": True, "max_length": self._context_length} + ) + + encoding = self._tokenizer(input_text, **encode_kwargs) + input_ids = encoding["input_ids"] + + # ------------------------------------------------------------------ # + # 2. Multimodal pre-processing + # ------------------------------------------------------------------ # + mm_inputs = self._collect_mm_inputs(raw_request, text=input_text) + + # ------------------------------------------------------------------ # + # 3. Pack into the typed dataclass + # ------------------------------------------------------------------ # + return TokenizedGenerateReqInput( + rid=raw_request.get("rid"), + input_text=input_text, + input_ids=input_ids, + mm_inputs=mm_inputs, + sampling_params=raw_request.get("sampling_params") or {}, + stream=bool(raw_request.get("stream", False)), + return_logprob=bool(raw_request.get("return_logprob", False)), + logprob_start_len=int(raw_request.get("logprob_start_len", -1)), + top_logprobs_num=int(raw_request.get("top_logprobs_num", 0)), + lora_path=raw_request.get("lora_path"), + session_params=raw_request.get("session_params"), + ) + + def _normalize_image_input(self, image_data: Any) -> List[Any]: + """Normalise ``image_data`` into a list of image-like objects. 
+ + Supported input forms: + - single PIL.Image / numpy array / torch.Tensor + - path string or bytes + - list/tuple of the above + """ + + def _to_image(obj: Any) -> Any: + # Lazily import Pillow to avoid hard dependency for text-only models. + try: + from PIL import Image # type: ignore + except Exception as exc: # pragma: no cover - optional dependency + raise RuntimeError( + "Pillow is required for image preprocessing in TokenizerProcess" + ) from exc + + if obj is None: + return None + if isinstance(obj, Image.Image): + return obj + if isinstance(obj, (str, bytes)): + return Image.open(obj) + return obj + + if isinstance(image_data, (list, tuple)): + return [ + img for img in (_to_image(x) for x in image_data) if img is not None + ] + return [img for img in (_to_image(image_data),) if img is not None] + + def _collect_mm_inputs( + self, raw_request: Dict[str, Any], text: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + """Pre-process multimodal data and return a consolidated ``mm_inputs`` dict. + + Returns ``None`` for text-only requests. Otherwise returns a flat dict + whose keys are ready to be unpacked by the model runner: + + * ``image_inputs`` – output of ``AutoProcessor`` (contains + ``pixel_values``, etc.) when a processor is available. + * ``image_data`` – raw image objects when no processor is available. + * ``audio_data`` – forwarded verbatim (no processor yet). + * ``video_data`` – forwarded verbatim (no processor yet). + """ + image_data = raw_request.get("image_data") + video_data = raw_request.get("video_data") + audio_data = raw_request.get("audio_data") + + if not any(x is not None for x in (image_data, video_data, audio_data)): + return None # text-only request + + mm: Dict[str, Any] = {} + + # Image: prefer AutoProcessor output; fall back to raw data. 
def run_tokenizer_process(
    recv_from_rr_addr: str,
    send_to_scheduler_addr: str,
    pipe_writer: Connection,
    tokenizer_cfg: Dict[str, Any],
    shared_queue: Optional[TensorQueue] = None,
) -> None:
    """Entry point for ``torch.multiprocessing.Process(target=...)``.

    Builds the worker, opens its sockets, signals readiness to the parent
    over *pipe_writer*, then blocks in the event loop until interrupted.
    """
    setup_subprocess_logging(tokenizer_cfg.get("log_level", "info"))

    worker = TokenizerProcess(
        recv_from_rr_addr, send_to_scheduler_addr, tokenizer_cfg, shared_queue
    )
    worker.init_sockets()

    # Tell the parent process we are ready to receive work.
    pipe_writer.send({"status": "ready", "process": "tokenizer"})
    pipe_writer.close()

    try:
        worker.event_loop()
    except KeyboardInterrupt:
        pass
    finally:
        worker.shutdown()
"ToolCallParser", + "ToolCallItem", +] diff --git a/pymllm/parsers/reasoning_parser.py b/pymllm/parsers/reasoning_parser.py new file mode 100644 index 000000000..1f73c7885 --- /dev/null +++ b/pymllm/parsers/reasoning_parser.py @@ -0,0 +1,212 @@ +"""Reasoning / thinking content parser. + +Separates ``...`` (or model-specific markers) from normal +assistant content. Supports both one-shot and incremental streaming modes. + +Usage:: + + # Non-streaming + parser = ReasoningParser("qwen3") + reasoning, content = parser.parse_non_stream(full_text) + + # Streaming + parser = ReasoningParser("qwen3") + for delta in deltas: + reasoning_delta, content_delta = parser.parse_stream_chunk(delta) +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Type + + +# --------------------------------------------------------------------------- +# Detector registry +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class _DetectorConfig: + start: str + end: str + force: bool # True = always assume reasoning at start + + +_DETECTOR_MAP: Dict[str, _DetectorConfig] = { + # DeepSeek-R1: always starts in reasoning mode + "deepseek-r1": _DetectorConfig("", "", force=True), + # Qwen3: optional thinking (controlled by request) + "qwen3": _DetectorConfig("", "", force=False), + # Qwen3 forced thinking + "qwen3-thinking": _DetectorConfig("", "", force=True), + # GLM-4.5 + "glm45": _DetectorConfig("", "", force=False), + # Kimi + "kimi": _DetectorConfig("\u25c1think\u25b7", "\u25c1/think\u25b7", force=False), +} + + +# --------------------------------------------------------------------------- +# ReasoningParser +# --------------------------------------------------------------------------- + + +class ReasoningParser: + """Model-agnostic reasoning content parser. + + Parameters + ---------- + model_type + Key into the detector registry (e.g. ``"qwen3"``, ``"deepseek-r1"``). 
+ stream_reasoning + If ``True``, stream reasoning content incrementally as it arrives. + If ``False``, buffer reasoning until the end tag is found. + """ + + SUPPORTED = set(_DETECTOR_MAP) + + def __init__(self, model_type: str, stream_reasoning: bool = True): + cfg = _DETECTOR_MAP.get(model_type) + if cfg is None: + raise ValueError( + f"Unknown reasoning parser {model_type!r}. " + f"Supported: {sorted(_DETECTOR_MAP)}" + ) + self._start = cfg.start + self._end = cfg.end + self._force = cfg.force + self._stream_reasoning = stream_reasoning + + # -- streaming state -- + self._buffer = "" + self._in_reasoning = cfg.force + self._start_consumed = False # True once start tag has been stripped + self._done = False # True once end tag has been seen + + # ------------------------------------------------------------------ # + # Non-streaming + # ------------------------------------------------------------------ # + + def parse_non_stream(self, text: str) -> Tuple[Optional[str], str]: + """Parse complete text. + + Returns ``(reasoning_content, content)`` where either may be empty. 
+ """ + start_idx = text.find(self._start) + end_idx = text.find(self._end) + + if start_idx == -1 and not self._force: + return None, text + + # Determine boundaries + if self._force and start_idx == -1: + # Model didn't emit explicit start tag; treat prefix as reasoning + reason_start = 0 + else: + reason_start = start_idx + len(self._start) + + before = text[:start_idx] if start_idx != -1 else "" + + if end_idx != -1 and end_idx >= reason_start: + reasoning = text[reason_start:end_idx] + after = text[end_idx + len(self._end) :] + else: + reasoning = text[reason_start:] + after = "" + + content = (before + after).strip() + reasoning = reasoning.strip() + return reasoning or None, content + + # ------------------------------------------------------------------ # + # Streaming + # ------------------------------------------------------------------ # + + def parse_stream_chunk(self, delta: str) -> Tuple[str, str]: + """Parse an incremental streaming delta. + + Returns ``(reasoning_delta, content_delta)``. Either may be ``""``. + """ + if not delta: + return "", "" + + if self._done: + return "", delta + + self._buffer += delta + reasoning_out = "" + content_out = "" + + # In forced reasoning mode, consume the start tag if it appears + # (the model may or may not emit it explicitly). 
    def parse_stream_chunk(self, delta: str) -> Tuple[str, str]:
        """Parse an incremental streaming delta.

        Returns ``(reasoning_delta, content_delta)``. Either may be ``""``.
        """
        if not delta:
            return "", ""

        # After the end tag has been seen, everything is plain content.
        if self._done:
            return "", delta

        self._buffer += delta
        reasoning_out = ""
        content_out = ""

        # In forced reasoning mode, consume the start tag if it appears
        # (the model may or may not emit it explicitly).
        if self._in_reasoning and not self._start_consumed:
            idx = self._buffer.find(self._start)
            if idx != -1:
                # Start tag found — strip it and any text before it
                self._buffer = self._buffer[idx + len(self._start) :]
                self._start_consumed = True
            elif _could_be_partial(self._buffer, self._start):
                # Might be a partial start tag — hold the buffer
                return "", ""
            else:
                # No start tag coming — mark consumed and continue
                self._start_consumed = True

        if not self._in_reasoning:
            # --- look for start tag ---
            idx = self._buffer.find(self._start)
            if idx != -1:
                content_out += self._buffer[:idx]
                self._buffer = self._buffer[idx + len(self._start) :]
                self._in_reasoning = True
                self._start_consumed = True
            elif _could_be_partial(self._buffer, self._start):
                # Potential partial match at tail — hold the buffer
                # (only the suffix that could still be a tag is retained).
                safe = len(self._buffer) - len(self._start) + 1
                if safe > 0:
                    content_out += self._buffer[:safe]
                    self._buffer = self._buffer[safe:]
                return "", content_out
            else:
                content_out += self._buffer
                self._buffer = ""
                return "", content_out

        if self._in_reasoning:
            # --- look for end tag ---
            idx = self._buffer.find(self._end)
            if idx != -1:
                reasoning_out += self._buffer[:idx]
                after = self._buffer[idx + len(self._end) :]
                self._buffer = ""
                self._in_reasoning = False
                self._done = True
                if after:
                    content_out += after
            elif _could_be_partial(self._buffer, self._end):
                # Hold back only the suffix that might be a partial end tag.
                safe = len(self._buffer) - len(self._end) + 1
                if safe > 0:
                    reasoning_out += self._buffer[:safe]
                    self._buffer = self._buffer[safe:]
            else:
                reasoning_out += self._buffer
                self._buffer = ""

        # Optionally suppress incremental reasoning output (buffered mode).
        if not self._stream_reasoning:
            reasoning_out = ""

        return reasoning_out, content_out
prefix of *pattern*.""" + for i in range(1, len(pattern)): + if text.endswith(pattern[:i]): + return True + return False diff --git a/pymllm/parsers/tool_call_parser.py b/pymllm/parsers/tool_call_parser.py new file mode 100644 index 000000000..fdfe93914 --- /dev/null +++ b/pymllm/parsers/tool_call_parser.py @@ -0,0 +1,433 @@ +"""Tool-call (function-calling) output parser. + +Extracts structured tool calls from model output text. Supports both +one-shot and incremental streaming modes. + +Formats supported: + +* **qwen25** — ``{"name":...,"arguments":...}`` +* **llama3** — ``<|python_tag|>{"name":...,"parameters":...}`` +* **hermes** — ``{"name":...,"arguments":...}`` (same tags, Hermes schema) + +Usage:: + + # Non-streaming + parser = ToolCallParser("qwen25", tools=tools_list) + content, tool_calls = parser.parse_non_stream(full_text) + + # Streaming + parser = ToolCallParser("qwen25", tools=tools_list) + for delta in deltas: + content_delta, tool_call_deltas = parser.parse_stream_chunk(delta) +""" + +from __future__ import annotations + +import json +import re +import uuid +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + + +@dataclass +class ToolCallItem: + """A single parsed tool call.""" + + name: Optional[str] = None + arguments: str = "" + tool_call_id: str = "" + index: int = 0 + + def to_openai_dict(self, streaming: bool = True) -> Dict[str, Any]: + """Convert to OpenAI ``tool_calls[]`` element format. + + Parameters + ---------- + streaming + If True, include ``index`` (streaming delta format). + If False, omit ``index`` (non-streaming message format). 
+ """ + d: Dict[str, Any] = {"type": "function", "function": {}} + if streaming: + d["index"] = self.index + if self.tool_call_id: + d["id"] = self.tool_call_id + fn: Dict[str, Any] = d["function"] + if self.name is not None: + fn["name"] = self.name + fn["arguments"] = self.arguments or "" + return d + + +# --------------------------------------------------------------------------- +# Detector base +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class _FormatConfig: + bot_token: str + end_token: str + # Regex to extract individual call bodies between bot/end tokens. + # If None, the entire text between bot and end tokens is one call. + call_regex: Optional[str] = None + + +_FORMAT_MAP: Dict[str, _FormatConfig] = { + "qwen25": _FormatConfig( + bot_token="\n", + end_token="\n", + ), + "qwen3_coder": _FormatConfig( + bot_token="", + end_token="", + ), + "hermes": _FormatConfig( + bot_token="\n", + end_token="\n", + ), + "llama3": _FormatConfig( + bot_token="<|python_tag|>", + end_token="", # Llama3 uses EOT, detected via EOS + ), +} + + +# --------------------------------------------------------------------------- +# ToolCallParser +# --------------------------------------------------------------------------- + + +class ToolCallParser: + """Model-agnostic tool-call parser. + + Parameters + ---------- + model_type + Key into the format registry (e.g. ``"qwen25"``, ``"llama3"``). + tools + The ``tools`` list from the OpenAI chat request (used to resolve + function names). + """ + + SUPPORTED = set(_FORMAT_MAP) + + def __init__(self, model_type: str, tools: Optional[List[Any]] = None): + cfg = _FORMAT_MAP.get(model_type) + if cfg is None: + raise ValueError( + f"Unknown tool-call parser {model_type!r}. 
" + f"Supported: {sorted(_FORMAT_MAP)}" + ) + self._bot = cfg.bot_token + self._end = cfg.end_token + self._model_type = model_type + self._tools = tools or [] + + # -- streaming state -- + self._buffer = "" + self._in_call = False + self._current_tool_idx = 0 + self._current_call_buf = "" + self._prev_args_len = 0 + self._name_sent = False + self._completed_calls: List[ToolCallItem] = [] + + # ------------------------------------------------------------------ # + # Non-streaming + # ------------------------------------------------------------------ # + + def has_tool_call(self, text: str) -> bool: + """Return True if *text* contains a tool-call pattern.""" + return self._bot in text + + def parse_non_stream( + self, text: str + ) -> Tuple[str, List[ToolCallItem]]: + """Parse complete text. + + Returns ``(remaining_content, tool_calls)``. + """ + if not self.has_tool_call(text): + return text, [] + + tool_calls: List[ToolCallItem] = [] + normal_parts: List[str] = [] + + remaining = text + idx = 0 + while True: + bot_pos = remaining.find(self._bot) + if bot_pos == -1: + normal_parts.append(remaining) + break + normal_parts.append(remaining[:bot_pos]) + remaining = remaining[bot_pos + len(self._bot) :] + + if self._end: + end_pos = remaining.find(self._end) + if end_pos == -1: + call_body = remaining + remaining = "" + else: + call_body = remaining[:end_pos] + remaining = remaining[end_pos + len(self._end) :] + else: + call_body = remaining + remaining = "" + + parsed = self._parse_call_body(call_body.strip()) + if parsed is not None: + parsed.index = idx + parsed.tool_call_id = _make_tool_call_id() + tool_calls.append(parsed) + idx += 1 + + content = "".join(normal_parts).strip() + return content, tool_calls + + # ------------------------------------------------------------------ # + # Streaming + # ------------------------------------------------------------------ # + + def parse_stream_chunk( + self, delta: str + ) -> Tuple[str, List[ToolCallItem]]: + """Parse an 
    def parse_stream_chunk(
        self, delta: str
    ) -> Tuple[str, List[ToolCallItem]]:
        """Parse an incremental streaming delta.

        Returns ``(content_delta, tool_call_items)``.

        For tool call items:
        - First item for a call: ``name`` is set, ``arguments`` is ``""``.
        - Subsequent items: ``name`` is ``None``, ``arguments`` is the new
          characters appended (argument delta).
        """
        if not delta:
            return "", []

        self._buffer += delta
        content_out = ""
        items: List[ToolCallItem] = []

        # State machine: alternate between scanning for the next bot token
        # (plain content) and accumulating a call body until the end token.
        while True:
            if not self._in_call:
                # --- look for bot token ---
                bot_pos = self._buffer.find(self._bot)
                if bot_pos != -1:
                    content_out += self._buffer[:bot_pos]
                    self._buffer = self._buffer[bot_pos + len(self._bot) :]
                    self._in_call = True
                    self._current_call_buf = ""
                    self._prev_args_len = 0
                    self._name_sent = False
                    continue  # try to process call content
                else:
                    # Check for partial bot token at tail; emit only the
                    # prefix that can no longer be part of the token.
                    if self._bot and _could_be_partial(self._buffer, self._bot):
                        safe = len(self._buffer) - len(self._bot) + 1
                        if safe > 0:
                            content_out += self._buffer[:safe]
                            self._buffer = self._buffer[safe:]
                    else:
                        content_out += self._buffer
                        self._buffer = ""
                    break

            if self._in_call:
                # --- look for end token ---
                if self._end:
                    end_pos = self._buffer.find(self._end)
                    if end_pos != -1:
                        self._current_call_buf += self._buffer[:end_pos]
                        self._buffer = self._buffer[end_pos + len(self._end) :]
                        # Emit final tool call
                        item = self._finalize_call()
                        if item is not None:
                            items.append(item)
                        self._in_call = False
                        self._current_tool_idx += 1
                        continue  # there may be more calls
                    else:
                        # Accumulate and stream arguments
                        self._current_call_buf += self._buffer
                        self._buffer = ""
                        item = self._stream_partial_call()
                        if item is not None:
                            items.append(item)
                        break
                else:
                    # No end token (e.g. Llama3) — accumulate everything
                    self._current_call_buf += self._buffer
                    self._buffer = ""
                    item = self._stream_partial_call()
                    if item is not None:
                        items.append(item)
                    break

        return content_out, items
Llama3) — accumulate everything + self._current_call_buf += self._buffer + self._buffer = "" + item = self._stream_partial_call() + if item is not None: + items.append(item) + break + + return content_out, items + + def flush(self) -> List[ToolCallItem]: + """Flush any remaining buffered tool call (call at request end).""" + items: List[ToolCallItem] = [] + if self._in_call and self._current_call_buf.strip(): + item = self._finalize_call() + if item is not None: + items.append(item) + self._in_call = False + return items + + # ------------------------------------------------------------------ # + # Internal helpers + # ------------------------------------------------------------------ # + + def _parse_call_body(self, body: str) -> Optional[ToolCallItem]: + """Parse a single call body (JSON or qwen3_coder XML-style).""" + if self._model_type == "qwen3_coder": + return self._parse_qwen3_coder_body(body) + try: + obj = json.loads(body) + except json.JSONDecodeError: + return None + name = obj.get("name") + args = obj.get("arguments") or obj.get("parameters") or {} + if isinstance(args, dict): + args = json.dumps(args, ensure_ascii=False) + return ToolCallItem(name=name, arguments=args) + + @staticmethod + def _parse_qwen3_coder_body(body: str) -> Optional[ToolCallItem]: + """Parse qwen3_coder XML-style: ``V...``.""" + # Extract function name + func_m = re.search(r"]+)>", body) + if func_m is None: + return None + name = func_m.group(1) + # Extract parameters + params: Dict[str, Any] = {} + for pm in re.finditer( + r"]+)>(.*?)(?:|(?=))", + body, + re.DOTALL, + ): + key = pm.group(1) + val = pm.group(2).strip() + # Try to parse as JSON value, otherwise keep as string + try: + params[key] = json.loads(val) + except (json.JSONDecodeError, ValueError): + params[key] = val + return ToolCallItem( + name=name, + arguments=json.dumps(params, ensure_ascii=False), + ) + + def _stream_partial_call(self) -> Optional[ToolCallItem]: + """Try to extract streaming information from the 
partial call.""" + body = self._current_call_buf.strip() + if not body: + return None + + # Try to extract name first + if not self._name_sent: + name = self._try_extract_name(body) + if name is not None: + self._name_sent = True + return ToolCallItem( + name=name, + arguments="", + tool_call_id=_make_tool_call_id(), + index=self._current_tool_idx, + ) + return None + + # Stream argument characters + args_str = self._try_extract_args_partial(body) + if args_str is not None and len(args_str) > self._prev_args_len: + new_chars = args_str[self._prev_args_len :] + self._prev_args_len = len(args_str) + return ToolCallItem( + name=None, + arguments=new_chars, + index=self._current_tool_idx, + ) + return None + + def _finalize_call(self) -> Optional[ToolCallItem]: + """Finalize a complete call — emit any remaining argument chars.""" + parsed = self._parse_call_body(self._current_call_buf.strip()) + if parsed is None: + return None + + if not self._name_sent: + # Entire call came at once + parsed.index = self._current_tool_idx + parsed.tool_call_id = _make_tool_call_id() + return parsed + + # Name was already sent — emit remaining arguments + full_args = parsed.arguments + new_chars = full_args[self._prev_args_len :] + if new_chars: + return ToolCallItem( + name=None, + arguments=new_chars, + index=self._current_tool_idx, + ) + return None + + def _try_extract_name(self, partial: str) -> Optional[str]: + """Try to extract function name from partial call body.""" + if self._model_type == "qwen3_coder": + m = re.search(r"]+)>", partial) + return m.group(1) if m else None + m = re.search(r'"name"\s*:\s*"([^"]+)"', partial) + return m.group(1) if m else None + + def _try_extract_args_partial(self, partial: str) -> Optional[str]: + """Try to extract partial arguments from call body.""" + if self._model_type == "qwen3_coder": + # Build JSON incrementally from V tags + params: Dict[str, Any] = {} + for pm in re.finditer( + r"]+)>(.*?)(?:)", + partial, + re.DOTALL, + ): + key = 
pm.group(1) + val = pm.group(2).strip() + try: + params[key] = json.loads(val) + except (json.JSONDecodeError, ValueError): + params[key] = val + if params: + return json.dumps(params, ensure_ascii=False) + return None + m = re.search(r'"arguments"\s*:\s*(\{.*)', partial, re.DOTALL) + if m: + return m.group(1) + m = re.search(r'"parameters"\s*:\s*(\{.*)', partial, re.DOTALL) + if m: + return m.group(1) + return None + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_tool_call_id() -> str: + return f"call_{uuid.uuid4().hex[:24]}" + + +def _could_be_partial(text: str, pattern: str) -> bool: + for i in range(1, len(pattern)): + if text.endswith(pattern[:i]): + return True + return False diff --git a/pymllm/quantization/__init__.py b/pymllm/quantization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymllm/quantization/methods/__init__.py b/pymllm/quantization/methods/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymllm/quantization/methods/awq_w4a16.py b/pymllm/quantization/methods/awq_w4a16.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymllm/quantization/quant_recipe.py b/pymllm/quantization/quant_recipe.py new file mode 100644 index 000000000..a5b493bec --- /dev/null +++ b/pymllm/quantization/quant_recipe.py @@ -0,0 +1,3 @@ +class QuantRecipe: + def __init__(self): + pass diff --git a/pymllm/server/__init__.py b/pymllm/server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymllm/server/launch.py b/pymllm/server/launch.py new file mode 100644 index 000000000..b9f603220 --- /dev/null +++ b/pymllm/server/launch.py @@ -0,0 +1,936 @@ +"""pymllm HTTP server -- RESTful API entry point. 
def _get_engine() -> Engine:
    """Return the running engine, raising if startup has not completed."""
    engine = _engine
    if engine is None:
        raise RuntimeError("Engine not initialised")
    return engine
class GenerateRequest(BaseModel):
    """Body for ``POST /generate``.

    Most fields accept either a single value or a batch (list); they are
    forwarded to ``Engine.generate_async`` after ``None`` values are
    stripped. Unknown extra keys are allowed and passed through.
    """

    # Prompt input: raw text and/or pre-tokenized ids.
    text: Optional[Union[List[str], str]] = None
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    sampling_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None
    # Multimodal payloads (format defined by the engine).
    image_data: Optional[Any] = None
    audio_data: Optional[Any] = None
    video_data: Optional[Any] = None
    # Logprob controls.
    return_logprob: Optional[Union[List[bool], bool]] = None
    logprob_start_len: Optional[Union[List[int], int]] = None
    top_logprobs_num: Optional[Union[List[int], int]] = None
    lora_path: Optional[Union[List[Optional[str]], str]] = None
    session_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None
    stream: bool = False
    # Client-supplied request id(s), usable with /abort_request.
    rid: Optional[Union[List[str], str]] = None

    model_config = {"extra": "allow"}  # forward unknown keys as extra_options


# -- OpenAI-compatible models -----------------------------------------------


class ImageUrl(BaseModel):
    """OpenAI-style image reference inside a multimodal content part."""

    url: str
    detail: Optional[str] = "auto"


class ContentPart(BaseModel):
    """One element of a multimodal message content list."""

    type: str
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None


class ChatMessage(BaseModel):
    """A single chat turn (OpenAI message object)."""

    role: str
    content: Optional[Union[str, List[ContentPart]]] = None
    name: Optional[str] = None
    tool_calls: Optional[List[Any]] = None
    tool_call_id: Optional[str] = None

    model_config = {"extra": "allow"}


class StreamOptions(BaseModel):
    """OpenAI ``stream_options`` object."""

    include_usage: Optional[bool] = False
    continuous_usage_stats: Optional[bool] = False


class ToolFunction(BaseModel):
    """Function description inside a tool definition."""

    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class Tool(BaseModel):
    """OpenAI tool definition (only ``function`` type is modelled)."""

    type: str = "function"
    function: ToolFunction


class ChatCompletionRequest(BaseModel):
    """OpenAI ``POST /v1/chat/completions`` body."""

    model: str = ""
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None
    # When set, max_completion_tokens wins over max_tokens (OpenAI convention).
    max_completion_tokens: Optional[int] = None
    stream: bool = False
    stream_options: Optional[StreamOptions] = None
    stop: Optional[Union[str, List[str]]] = None
    n: int = 1
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    repetition_penalty: Optional[float] = None
    seed: Optional[int] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    user: Optional[str] = None
    # Tool calling
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
    # Reasoning control
    separate_reasoning: bool = True
    stream_reasoning: bool = True
    # Pass-through to tokenizer.apply_chat_template (e.g. enable_thinking)
    chat_template_kwargs: Optional[Dict[str, Any]] = None

    model_config = {"extra": "allow"}


class CompletionRequest(BaseModel):
    """OpenAI ``POST /v1/completions`` body."""

    model: str = ""
    prompt: Union[str, List[str]]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None
    stream: bool = False
    stream_options: Optional[StreamOptions] = None
    stop: Optional[Union[str, List[str]]] = None
    n: int = 1
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    repetition_penalty: Optional[float] = None
    seed: Optional[int] = None
    echo: bool = False
    logprobs: Optional[int] = None
    user: Optional[str] = None

    model_config = {"extra": "allow"}


class AbortRequest(BaseModel):
    """Body for ``POST /abort_request``."""

    rid: Optional[str] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup / shutdown hooks for the FastAPI app.

    On startup: picks up the Engine attached by ``launch_server`` via
    ``app.state.engine`` and loads a tokenizer for chat templating.
    On shutdown: shuts the engine down and clears the global handle.
    """
    global _engine, _tokenizer
    _engine = app.state.engine  # type: ignore[attr-defined]

    # Load tokenizer in server process for apply_chat_template
    cfg = get_global_config()
    try:
        from transformers import AutoTokenizer

        _tokenizer = AutoTokenizer.from_pretrained(
            str(cfg.server.tokenizer_path),
            trust_remote_code=cfg.server.trust_remote_code,
        )
        logger.info(
            "Loaded tokenizer for chat template: %s", cfg.server.tokenizer_path
        )
    except Exception as e:
        # Non-fatal: _messages_to_prompt falls back to ChatML formatting.
        logger.warning("Failed to load tokenizer for chat template: %s", e)

    logger.info(
        "HTTP server ready at http://%s:%s",
        cfg.server.host,
        cfg.server.port,
    )
    yield
    # Shutdown
    if _engine is not None:
        _engine.shutdown()
        _engine = None


app = FastAPI(lifespan=lifespan)
# Permissive CORS: any origin/method/header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---------------------------------------------------------------------------
# Exception handlers
# ---------------------------------------------------------------------------


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render HTTPException as an OpenAI-style ``{"error": {...}}`` body."""
    return ORJSONResponse(
        content={"error": {"message": exc.detail, "code": exc.status_code}},
        status_code=exc.status_code,
    )


# ---------------------------------------------------------------------------
# Health / info endpoints
# ---------------------------------------------------------------------------


@app.get("/health")
@app.get("/health_generate")
async def health():
    """Liveness probe."""
    return Response(status_code=200)


@app.get("/model_info")
async def model_info():
    """Return basic model metadata."""
    cfg = get_global_config()
    hf_cfg = cfg.model.hf_config
    # hf_cfg may be None; getattr guards cover configs lacking these fields.
    return {
        "model_path": str(cfg.server.model_path),
        "tokenizer_path": str(cfg.server.tokenizer_path),
        "served_model_name": cfg.server.served_model_name,
        "model_type": getattr(hf_cfg, "model_type", None) if hf_cfg else None,
        "architectures": getattr(hf_cfg, "architectures", None) if hf_cfg else None,
    }
@app.get("/server_info")
async def server_info():
    """Dump runtime server configuration as a plain dict."""
    import dataclasses as _dc

    cfg = get_global_config()
    return _dc.asdict(cfg.server)


@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing (always exactly one model)."""
    cfg = get_global_config()
    # served_model_name takes precedence; fall back to the model path.
    model_name = cfg.server.served_model_name or str(cfg.server.model_path)
    return {
        "object": "list",
        "data": [_model_card(model_name)],
    }


@app.get("/v1/models/{model_id:path}")
async def retrieve_model(model_id: str):
    """OpenAI-compatible single model retrieval; 404 on name mismatch."""
    cfg = get_global_config()
    model_name = cfg.server.served_model_name or str(cfg.server.model_path)
    if model_id != model_name:
        raise HTTPException(
            status_code=404,
            detail=f"Model '{model_id}' not found. Available: '{model_name}'",
        )
    return _model_card(model_name)


def _model_card(model_name: str) -> Dict[str, Any]:
    """Build an OpenAI-compatible Model object."""
    return {
        "id": model_name,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "pymllm",
    }
+_FINISH_REASON_MAP = { + "eos": "stop", + "stop": "stop", + "length": "length", + "abort": "stop", +} + + +def _normalize_finish_reason(reason: Optional[str]) -> Optional[str]: + """Convert internal finish reason to OpenAI-compatible value.""" + if reason is None: + return None + return _FINISH_REASON_MAP.get(reason, reason) + + +def _build_sampling_params( + temperature: Optional[float] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + max_tokens: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + repetition_penalty: Optional[float] = None, + seed: Optional[int] = None, + **extra: Any, +) -> Dict[str, Any]: + """Build a sampling_params dict from OpenAI-style fields.""" + params: Dict[str, Any] = {} + if temperature is not None: + params["temperature"] = temperature + if top_p is not None: + params["top_p"] = top_p + if top_k is not None: + params["top_k"] = top_k + if max_tokens is not None: + params["max_new_tokens"] = max_tokens + if stop is not None: + params["stop"] = stop if isinstance(stop, list) else [stop] + if frequency_penalty is not None: + params["frequency_penalty"] = frequency_penalty + if presence_penalty is not None: + params["presence_penalty"] = presence_penalty + if repetition_penalty is not None: + params["repetition_penalty"] = repetition_penalty + if seed is not None: + params["seed"] = seed + params.update(extra) + return params + + +def _messages_to_prompt( + messages: List[ChatMessage], + chat_template_kwargs: Optional[Dict[str, Any]] = None, +) -> str: + """Render chat messages into a prompt string via the model's chat template. + + Uses ``tokenizer.apply_chat_template()`` when available (handles Llama, + Qwen, Mistral, etc. automatically). Falls back to ChatML format. + + Parameters + ---------- + chat_template_kwargs + Extra keyword arguments forwarded to ``apply_chat_template`` + (e.g. 
def _messages_to_prompt(
    messages: List[ChatMessage],
    chat_template_kwargs: Optional[Dict[str, Any]] = None,
) -> str:
    """Render chat messages into a prompt string via the model's chat template.

    Uses ``tokenizer.apply_chat_template()`` when available (handles Llama,
    Qwen, Mistral, etc. automatically). Falls back to ChatML format.

    Parameters
    ----------
    chat_template_kwargs
        Extra keyword arguments forwarded to ``apply_chat_template``
        (e.g. ``enable_thinking=True`` for Qwen3).
    """
    # Flatten each message into a plain dict for the tokenizer.
    msg_dicts: List[Dict[str, Any]] = []
    for msg in messages:
        content = msg.content
        if isinstance(content, list):
            # Multimodal: extract only text parts for the prompt string.
            text_parts = [p.text for p in content if p.type == "text" and p.text]
            content = "\n".join(text_parts) if text_parts else ""
        elif content is None:
            content = ""
        d: Dict[str, Any] = {"role": msg.role, "content": content}
        if msg.name is not None:
            d["name"] = msg.name
        msg_dicts.append(d)

    # _tokenizer is loaded (best-effort) in the lifespan hook; may be None.
    tokenizer = _tokenizer
    if tokenizer is not None and hasattr(tokenizer, "apply_chat_template"):
        try:
            extra = dict(chat_template_kwargs) if chat_template_kwargs else {}
            return tokenizer.apply_chat_template(
                msg_dicts,
                tokenize=False,
                add_generation_prompt=True,
                **extra,
            )
        except Exception as e:
            logger.warning("apply_chat_template failed, using fallback: %s", e)

    # Fallback: ChatML format (Qwen-style)
    parts: List[str] = []
    for m in msg_dicts:
        parts.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>")
    parts.append("<|im_start|>assistant\n")
    return "\n".join(parts)


def _extract_image_data(messages: List[ChatMessage]) -> Optional[List[str]]:
    """Extract image URLs / base64 strings from multimodal content parts.

    Returns ``None`` (not an empty list) when no images are present.
    """
    images: List[str] = []
    for msg in messages:
        if not isinstance(msg.content, list):
            continue
        for part in msg.content:
            if part.type == "image_url" and part.image_url is not None:
                images.append(part.image_url.url)
    return images if images else None


def _make_completion_id() -> str:
    # OpenAI-style completion id: "cmpl-" + 24 hex chars.
    return f"cmpl-{uuid.uuid4().hex[:24]}"


def _make_chat_completion_id() -> str:
    # OpenAI-style chat completion id: "chatcmpl-" + 24 hex chars.
    return f"chatcmpl-{uuid.uuid4().hex[:24]}"
@app.api_route("/generate", methods=["POST", "PUT"])
async def generate(obj: GenerateRequest, request: Request):
    """Native generation endpoint. Supports SSE streaming.

    Forwards the request body (including unknown extra keys) to
    ``Engine.generate_async``. When ``stream`` is set, emits SSE frames
    terminated by a ``data: [DONE]`` sentinel; otherwise collects all
    results and returns them as JSON.
    """
    engine = _get_engine()

    # Collect extra fields as extra_options
    known = set(GenerateRequest.model_fields.keys())
    extra_options = {k: v for k, v in obj.model_dump().items() if k not in known}

    kwargs: Dict[str, Any] = {
        "prompt": obj.text,
        "input_ids": obj.input_ids,
        "sampling_params": obj.sampling_params,
        "image_data": obj.image_data,
        "audio_data": obj.audio_data,
        "video_data": obj.video_data,
        "return_logprob": obj.return_logprob,
        "logprob_start_len": obj.logprob_start_len,
        "top_logprobs_num": obj.top_logprobs_num,
        "lora_path": obj.lora_path,
        "session_params": obj.session_params,
        "stream": obj.stream,
        "rid": obj.rid,
        **extra_options,
    }
    # Strip None values so Engine defaults are used
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    if obj.stream:

        async def _stream() -> AsyncIterator[bytes]:
            try:
                async for chunk in engine.generate_async(**kwargs):
                    # Stop consuming if the client went away.
                    if await request.is_disconnected():
                        break
                    # Skip empty intermediate chunks (e.g. special tokens
                    # stripped by the detokenizer)
                    if not chunk.get("delta") and not chunk.get("finished"):
                        continue
                    yield b"data: " + orjson.dumps(chunk) + b"\n\n"
            except Exception as e:
                # Surface the error as a final SSE frame rather than a 500.
                err = {"error": {"message": str(e)}}
                yield b"data: " + orjson.dumps(err) + b"\n\n"
            yield b"data: [DONE]\n\n"

        return StreamingResponse(_stream(), media_type="text/event-stream")

    try:
        results = []
        async for item in engine.generate_async(**kwargs):
            results.append(item)
        # Single prompt -> single object; batch -> list of objects.
        result = results[0] if len(results) == 1 else results
        return ORJSONResponse(result)
    except Exception as e:
        logger.error("[generate] Error: %s", e)
        raise HTTPException(status_code=400, detail=str(e))
@app.post("/v1/completions")
async def openai_completions(obj: CompletionRequest, request: Request):
    """OpenAI-compatible text completion endpoint.

    Translates the OpenAI body into engine sampling params, then either
    streams SSE ``text_completion`` chunks (optionally followed by a
    usage-only chunk when ``stream_options.include_usage`` is set) or
    returns a single aggregated JSON response.
    """
    engine = _get_engine()
    sp = _build_sampling_params(
        temperature=obj.temperature,
        top_p=obj.top_p,
        top_k=obj.top_k,
        max_tokens=obj.max_tokens,
        stop=obj.stop,
        frequency_penalty=obj.frequency_penalty,
        presence_penalty=obj.presence_penalty,
        repetition_penalty=obj.repetition_penalty,
        seed=obj.seed,
    )
    cfg = get_global_config()
    model_name = obj.model or cfg.server.served_model_name or str(cfg.server.model_path)
    include_usage = (
        obj.stream_options is not None and obj.stream_options.include_usage
    )

    if obj.stream:

        async def _stream() -> AsyncIterator[bytes]:
            comp_id = _make_completion_id()
            # Token counters carry the last value seen from the engine chunks.
            prompt_tokens = 0
            completion_tokens = 0
            try:
                async for chunk in engine.generate_async(
                    prompt=obj.prompt, sampling_params=sp, stream=True
                ):
                    if await request.is_disconnected():
                        break
                    prompt_tokens = chunk.get("prompt_tokens", prompt_tokens)
                    completion_tokens = chunk.get("completion_tokens", completion_tokens)
                    delta_text = chunk.get("delta", "")
                    finish_reason = _normalize_finish_reason(
                        chunk.get("finished_reason")
                    )
                    # Skip empty intermediate chunks
                    if not delta_text and finish_reason is None:
                        continue
                    sse: Dict[str, Any] = {
                        "id": comp_id,
                        "object": "text_completion",
                        "created": int(time.time()),
                        "model": model_name,
                        "choices": [
                            {
                                "index": 0,
                                "text": delta_text,
                                "logprobs": None,
                                "finish_reason": finish_reason,
                            }
                        ],
                    }
                    yield b"data: " + orjson.dumps(sse) + b"\n\n"
            except Exception as e:
                err = {"error": {"message": str(e)}}
                yield b"data: " + orjson.dumps(err) + b"\n\n"
            # Final usage-only chunk (OpenAI stream_options.include_usage)
            if include_usage:
                usage_chunk: Dict[str, Any] = {
                    "id": comp_id,
                    "object": "text_completion",
                    "created": int(time.time()),
                    "model": model_name,
                    "choices": [],
                    "usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": prompt_tokens + completion_tokens,
                    },
                }
                yield b"data: " + orjson.dumps(usage_chunk) + b"\n\n"
            yield b"data: [DONE]\n\n"

        return StreamingResponse(_stream(), media_type="text/event-stream")

    try:
        results = []
        async for item in engine.generate_async(
            prompt=obj.prompt, sampling_params=sp
        ):
            results.append(item)
        choices = []
        prompt_tokens = 0
        completion_tokens = 0
        # One choice per result (batched prompts produce multiple results).
        for i, r in enumerate(results):
            choices.append(
                {
                    "index": i,
                    "text": r.get("text", ""),
                    "logprobs": None,
                    "finish_reason": _normalize_finish_reason(
                        r.get("finished_reason", "stop")
                    ),
                }
            )
            prompt_tokens += r.get("prompt_tokens", 0)
            completion_tokens += r.get("completion_tokens", 0)

        return ORJSONResponse(
            {
                "id": _make_completion_id(),
                "object": "text_completion",
                "created": int(time.time()),
                "model": model_name,
                "choices": choices,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                },
            }
        )
    except Exception as e:
        logger.error("[v1/completions] Error: %s", e)
        raise HTTPException(status_code=400, detail=str(e))
@app.post("/v1/chat/completions")
async def openai_chat_completions(obj: ChatCompletionRequest, request: Request):
    """OpenAI-compatible chat completion endpoint with reasoning & tool-call parsing.

    Renders ``obj.messages`` through the chat template, generates via the
    engine, and post-processes output with the configured reasoning and
    tool-call parsers in both the streaming and non-streaming paths.
    """
    engine = _get_engine()
    cfg = get_global_config()
    # Auto-enable thinking when reasoning_parser is configured and the
    # client didn't explicitly set enable_thinking.
    chat_kwargs = dict(obj.chat_template_kwargs) if obj.chat_template_kwargs else {}
    if cfg.server.reasoning_parser and "enable_thinking" not in chat_kwargs:
        chat_kwargs["enable_thinking"] = True
    prompt = _messages_to_prompt(obj.messages, chat_template_kwargs=chat_kwargs or None)
    image_data = _extract_image_data(obj.messages)

    # max_completion_tokens takes precedence over max_tokens (OpenAI convention)
    max_tokens = obj.max_completion_tokens if obj.max_completion_tokens is not None else obj.max_tokens

    sp = _build_sampling_params(
        temperature=obj.temperature,
        top_p=obj.top_p,
        top_k=obj.top_k,
        max_tokens=max_tokens,
        stop=obj.stop,
        frequency_penalty=obj.frequency_penalty,
        presence_penalty=obj.presence_penalty,
        repetition_penalty=obj.repetition_penalty,
        seed=obj.seed,
    )
    cfg = get_global_config()  # NOTE(review): cfg already fetched above; redundant
    model_name = obj.model or cfg.server.served_model_name or str(cfg.server.model_path)
    include_usage = (
        obj.stream_options is not None and obj.stream_options.include_usage
    )

    # Resolve parsers from server config
    reasoning_type = cfg.server.reasoning_parser
    tool_call_type = cfg.server.tool_call_parser

    gen_kwargs: Dict[str, Any] = {
        "prompt": prompt,
        "sampling_params": sp,
    }
    if image_data is not None:
        gen_kwargs["image_data"] = image_data

    if obj.stream:

        async def _stream() -> AsyncIterator[bytes]:
            from pymllm.parsers import ReasoningParser, ToolCallParser

            comp_id = _make_chat_completion_id()
            created = int(time.time())
            first = True
            prompt_tokens = 0
            completion_tokens = 0
            has_tool_calls = False  # track across entire stream

            # Instantiate streaming parsers (only when configured & requested).
            r_parser = (
                ReasoningParser(reasoning_type, stream_reasoning=obj.stream_reasoning)
                if reasoning_type and obj.separate_reasoning
                else None
            )
            tc_parser = (
                ToolCallParser(tool_call_type, tools=obj.tools)
                if tool_call_type and obj.tools
                else None
            )

            def _make_sse(delta: Dict[str, Any], finish: Optional[str] = None) -> bytes:
                # Wrap a delta dict in a chat.completion.chunk SSE frame.
                sse: Dict[str, Any] = {
                    "id": comp_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model_name,
                    "choices": [
                        {
                            "index": 0,
                            "delta": delta,
                            "logprobs": None,
                            "finish_reason": finish,
                        }
                    ],
                }
                return b"data: " + orjson.dumps(sse) + b"\n\n"

            try:
                async for chunk in engine.generate_async(**gen_kwargs, stream=True):
                    if await request.is_disconnected():
                        break
                    prompt_tokens = chunk.get("prompt_tokens", prompt_tokens)
                    completion_tokens = chunk.get("completion_tokens", completion_tokens)

                    raw_delta = chunk.get("delta", "")
                    finish_reason = _normalize_finish_reason(
                        chunk.get("finished_reason")
                    )

                    # --- Phase 1: reasoning parser ---
                    reasoning_delta = ""
                    content_delta = raw_delta
                    if r_parser and raw_delta:
                        reasoning_delta, content_delta = r_parser.parse_stream_chunk(
                            raw_delta
                        )

                    # --- Phase 2: tool-call parser ---
                    tool_items: list = []
                    if tc_parser and content_delta:
                        content_delta, tool_items = tc_parser.parse_stream_chunk(
                            content_delta
                        )

                    # --- Emit chunks ---
                    # Role chunk (first)
                    if first:
                        yield _make_sse({"role": "assistant"})
                        first = False

                    # Reasoning content
                    if reasoning_delta:
                        yield _make_sse({"reasoning_content": reasoning_delta})

                    # Tool call deltas
                    if tool_items:
                        has_tool_calls = True
                        for tc in tool_items:
                            yield _make_sse({"tool_calls": [tc.to_openai_dict()]})

                    # Normal content
                    if content_delta:
                        yield _make_sse({"content": content_delta})

                    # Finish
                    if finish_reason is not None:
                        # Flush remaining tool call data
                        if tc_parser:
                            remaining = tc_parser.flush()
                            for tc in remaining:
                                has_tool_calls = True
                                yield _make_sse({"tool_calls": [tc.to_openai_dict()]})
                        # Any tool call anywhere in the stream overrides the reason.
                        if has_tool_calls:
                            finish_reason = "tool_calls"
                        yield _make_sse({}, finish=finish_reason)

            except Exception as e:
                err = {"error": {"message": str(e)}}
                yield b"data: " + orjson.dumps(err) + b"\n\n"
            # Final usage-only chunk
            if include_usage:
                usage_chunk: Dict[str, Any] = {
                    "id": comp_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model_name,
                    "choices": [],
                    "usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": prompt_tokens + completion_tokens,
                    },
                }
                yield b"data: " + orjson.dumps(usage_chunk) + b"\n\n"
            yield b"data: [DONE]\n\n"

        return StreamingResponse(_stream(), media_type="text/event-stream")

    # -- Non-streaming --
    try:
        from pymllm.parsers import ReasoningParser, ToolCallParser

        # Keep only the last item the engine yields (the final result).
        r = {}
        async for item in engine.generate_async(**gen_kwargs):
            r = item
        prompt_tokens = r.get("prompt_tokens", 0)
        completion_tokens = r.get("completion_tokens", 0)
        text = r.get("text", "")
        finish_reason = _normalize_finish_reason(r.get("finished_reason", "stop"))

        # Parse reasoning
        reasoning_content = None
        if reasoning_type and obj.separate_reasoning:
            rp = ReasoningParser(reasoning_type)
            reasoning_content, text = rp.parse_non_stream(text)

        # Parse tool calls
        tool_calls_list = None
        if tool_call_type and obj.tools:
            tp = ToolCallParser(tool_call_type, tools=obj.tools)
            if tp.has_tool_call(text):
                text, parsed_calls = tp.parse_non_stream(text)
                if parsed_calls:
                    tool_calls_list = [tc.to_openai_dict(streaming=False) for tc in parsed_calls]
                    finish_reason = "tool_calls"

        # OpenAI returns content: null (not "") when there is no plain text.
        message: Dict[str, Any] = {"role": "assistant", "content": text or None}
        if reasoning_content:
            message["reasoning_content"] = reasoning_content
        if tool_calls_list:
            message["tool_calls"] = tool_calls_list

        return ORJSONResponse(
            {
                "id": _make_chat_completion_id(),
                "object": "chat.completion",
                "created": int(time.time()),
                "model": model_name,
                "choices": [
                    {
                        "index": 0,
                        "message": message,
                        "logprobs": None,
                        "finish_reason": finish_reason,
                    }
                ],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                },
            }
        )
    except Exception as e:
        logger.error("[v1/chat/completions] Error: %s", e)
        raise HTTPException(status_code=400, detail=str(e))
@app.api_route("/flush_cache", methods=["GET", "POST"])
async def flush_cache():
    """Placeholder cache flush."""
    return Response(content="Cache flushed.\n", status_code=200)


@app.post("/abort_request")
async def abort_request(obj: AbortRequest):
    """Abort a running request by rid.

    Returns 200 on success; 400 when ``rid`` is missing or the engine's
    request router is unavailable.
    """
    engine = _get_engine()
    # NOTE(review): reaches into the private `_rr_process` attribute of
    # Engine — consider exposing a public abort API on Engine instead.
    if obj.rid and engine._rr_process is not None:
        await engine._rr_process.abort_request(obj.rid)
        return Response(status_code=200)
    raise HTTPException(status_code=400, detail="Missing or invalid rid")


# ---------------------------------------------------------------------------
# Prepare args helper
# ---------------------------------------------------------------------------


def _prepare_args():
    """Parse CLI arguments into the global config singleton."""
    parser = make_args()
    read_args(parser=parser)
def launch_server():
    """Launch the pymllm Engine then start the uvicorn HTTP server.

    This function mirrors sglang's ``launch_server``: it first boots all engine
    subprocesses (tokenizer, scheduler, model-runner, detokenizer) and then
    hands off to uvicorn to serve HTTP traffic.
    """
    _prepare_args()
    cfg = get_global_config()

    engine = Engine()
    engine.launch()

    # Attach engine to app.state so the lifespan hook can pick it up.
    app.state.engine = engine  # type: ignore[attr-defined]

    logger.info(
        "Starting HTTP server on %s:%s (root_path=%r)",
        cfg.server.host,
        cfg.server.port,
        cfg.server.fastapi_root_path,
    )

    # Blocks until the server is shut down; lifespan handles engine teardown.
    uvicorn.run(
        app,
        host=cfg.server.host,
        port=cfg.server.port,
        root_path=cfg.server.fastapi_root_path,
        log_level=cfg.server.log_level,
        timeout_keep_alive=5,
        loop="uvloop",
    )


def main():
    """CLI entry point."""
    launch_server()


if __name__ == "__main__":
    main()
def load_weight(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
    """Copy ``loaded_weight`` into ``param``.

    If the parameter exposes a ``weight_loader`` attribute, that callable
    is invoked as ``weight_loader(param, loaded_weight)`` so it can apply
    its own copy logic; otherwise the data is copied directly.
    """
    loader = getattr(param, "weight_loader", None)
    if loader is not None:
        # Delegate to the loader attached to the parameter.
        loader(param, loaded_weight)
    else:
        # Fallback: direct copy
        param.data.copy_(loaded_weight)
def run_worker_tp8_cuda(
    rank: int,
    local_rank: int,
    world_size: int,
    local_world_size: int,
    test_func: Callable,
    return_dict: dict,
):
    """Worker function for multi-process testing with TP=8 on CUDA.

    Args:
        rank: Global rank across all nodes
        local_rank: Local rank within this node (used for GPU binding)
        world_size: Total number of processes across all nodes
        local_world_size: Number of processes on this node
        test_func: Test function to run
        return_dict: Shared dict for returning results
    """
    os.environ["MASTER_ADDR"] = "localhost"
    # NOTE(review): fixed port may collide with a concurrent test run on the
    # same host — consider a free-port allocation. TODO confirm.
    os.environ["MASTER_PORT"] = "29500"

    # Set device using local_rank (binds to GPU 0,1,2,3 on this node)
    torch.cuda.set_device(local_rank)

    torch.distributed.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
    )

    initialize_model_parallel(tensor_model_parallel_size=8)

    try:
        result = test_func(rank, local_rank, world_size)
        return_dict[rank] = result
    except Exception as e:
        import traceback

        # Report the full traceback to the parent via the shared dict;
        # the parent asserts "ERROR" is absent from each rank's result.
        return_dict[rank] = f"ERROR: {e}\n{traceback.format_exc()}"
    finally:
        torch.distributed.destroy_process_group()
def embedding_forward_tp8_worker_cuda(rank: int, local_rank: int, world_size: int):
    """Test forward pass with real TP=8 on CUDA.

    Args:
        rank: Global rank
        local_rank: Local rank within this node (for logging/debugging)
        world_size: Total world size
    """
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    assert tp_size == 8, f"Rank {rank}: tp_size should be 8"
    assert tp_rank == rank, f"Rank {rank}: tp_rank mismatch"

    vocab_size = 1024
    embed_dim = 64
    # .cuda() uses the device set by torch.cuda.set_device(local_rank)
    layer = VocabParallelEmbedding(vocab_size, embed_dim).cuda()

    # Verify the layer is on the correct GPU
    assert layer.weight.device.index == local_rank, (
        f"Rank {rank}: weight should be on GPU {local_rank}, got {layer.weight.device}"
    )

    expected_shard_size = vocab_size // 8
    assert layer.num_embeddings_per_partition == expected_shard_size
    assert layer.weight.shape == (expected_shard_size, embed_dim)

    # Each rank initializes its own shard with known pattern
    with torch.no_grad():
        layer.weight.fill_(float(rank + 1))  # Rank 0: 1.0, Rank 1: 2.0, ...

    # Create input on the correct GPU. Token ids are multiples of 128, i.e.
    # the first entry of each rank's 128-row shard, so token 128*i is owned
    # by exactly rank i.
    input_ids = torch.tensor([[0, 128, 256, 384], [512, 640, 768, 896]], device="cuda")

    output = layer(input_ids)
    assert output.shape == (2, 4, embed_dim)

    # Verify output is on correct GPU
    assert output.device.index == local_rank, (
        f"Rank {rank}: output should be on GPU {local_rank}, got {output.device}"
    )

    if rank == 0:
        # Each token is owned by exactly one TP rank. Since each rank fills its
        # local shard with (rank + 1), post-all-reduce output must match below.
        expected_token_values = torch.tensor(
            [[1, 2, 3, 4], [5, 6, 7, 8]],
            device=output.device,
            dtype=output.dtype,
        )
        expected_output = expected_token_values.unsqueeze(-1).expand(-1, -1, embed_dim)

        if torch.equal(output, expected_output):
            return "PASSED"
        return "FAILED: embedding output does not match expected TP aggregation"

    # Non-zero ranks only report that they ran without raising.
    return "OK"
def weight_loading_tp8_worker_cuda(rank: int, local_rank: int, world_size: int):
    """Test weight loading with real TP=8 on CUDA.

    Args:
        rank: Global rank
        local_rank: Local rank within this node (for GPU binding verification)
        world_size: Total world size
    """
    vocab_size = 1024
    embed_dim = 64
    layer = VocabParallelEmbedding(vocab_size, embed_dim).cuda()

    # Verify the layer is on the correct GPU
    assert layer.weight.device.index == local_rank, (
        f"Rank {rank}: weight should be on GPU {local_rank}, got {layer.weight.device}"
    )

    # NOTE(review): full_weight is drawn independently in every process, so
    # the rank-0 "reconstruction" below only checks rank 0's own tensor
    # against itself — TODO confirm whether a shared seed was intended.
    full_weight = torch.randn(vocab_size, embed_dim)
    load_weight(layer.weight, full_weight.cuda())

    # The loader should have kept only this rank's contiguous vocab shard.
    shard_size = vocab_size // 8
    start_idx = rank * shard_size
    end_idx = start_idx + shard_size
    expected_shard = full_weight[start_idx:end_idx]

    if not torch.allclose(layer.weight.cpu(), expected_shard):
        return f"FAILED: shard mismatch at rank {rank}"

    if rank == 0:
        # Reassemble the full weight from this rank's loaded shard plus the
        # slices the other ranks would hold, and compare to the original.
        gathered_shards = [layer.weight.cpu().clone()]
        for other_rank in range(1, 8):
            other_shard = full_weight[
                other_rank * shard_size : (other_rank + 1) * shard_size
            ]
            gathered_shards.append(other_shard)

        reconstructed = torch.cat(gathered_shards, dim=0)
        if torch.allclose(reconstructed, full_weight):
            return "PASSED"
        else:
            return "FAILED: reconstruction mismatch"

    return "OK"
test_forward_pass_tp8_real(self): + """Test forward pass with real TP=8 using 8 processes on CUDA.""" + world_size = 8 + local_world_size = 8 # Single node with 8 GPUs + + mp.set_start_method("spawn", force=True) + + manager = mp.Manager() + return_dict = manager.dict() + + processes = [] + for rank in range(world_size): + # In single-node setup, local_rank == rank + local_rank = rank + p = mp.Process( + target=run_worker_tp8_cuda, + args=( + rank, + local_rank, + world_size, + local_world_size, + embedding_forward_tp8_worker_cuda, + return_dict, + ), + ) + p.start() + processes.append(p) + + for p in processes: + p.join(timeout=120) + if p.is_alive(): + p.terminate() + p.join() + + for rank in range(world_size): + result = return_dict.get(rank, "TIMEOUT") + if rank == 0: + assert result == "PASSED", f"Rank {rank} failed: {result}" + else: + assert "ERROR" not in str(result), f"Rank {rank} error: {result}" + + def test_weight_loading_tp8_real(self): + """Test weight loading with real TP=8 using 8 processes on CUDA.""" + world_size = 8 + local_world_size = 8 # Single node with 8 GPUs + + mp.set_start_method("spawn", force=True) + + manager = mp.Manager() + return_dict = manager.dict() + + processes = [] + for rank in range(world_size): + # In single-node setup, local_rank == rank + local_rank = rank + p = mp.Process( + target=run_worker_tp8_cuda, + args=( + rank, + local_rank, + world_size, + local_world_size, + weight_loading_tp8_worker_cuda, + return_dict, + ), + ) + p.start() + processes.append(p) + + for p in processes: + p.join(timeout=120) + if p.is_alive(): + p.terminate() + p.join() + + for rank in range(world_size): + result = return_dict.get(rank, "TIMEOUT") + if rank == 0: + assert result == "PASSED", f"Rank {rank} failed: {result}" + else: + assert "ERROR" not in str(result), f"Rank {rank} error: {result}" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestVocabParallelEmbeddingCUDA: + """Tests for 
non-parallel TP=1 mode on CUDA.""" + + @pytest.fixture(autouse=True) + def setup_config(self): + import pymllm.orchestrator.parallel_state as ps + ps._TP_SIZE = 1 + ps._TP_RANK = 0 + yield + ps._TP_SIZE = 1 + ps._TP_RANK = 0 + + def test_cuda_forward(self): + layer = VocabParallelEmbedding(1000, 512).cuda() + input_ids = torch.randint(0, 1000, (4, 32), device="cuda") + + output = layer(input_ids) + + assert output.device.type == "cuda" + assert output.shape == (4, 32, 512) + + def test_cuda_weight_loader(self): + layer = VocabParallelEmbedding(100, 64).cuda() + + cpu_weight = torch.randn(100, 64) + load_weight(layer.weight, cpu_weight.cuda()) + + assert torch.allclose(layer.weight.cpu(), cpu_weight) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/pymllm/utils/mllm_convertor_server/service.py b/pymllm/utils/mllm_convertor_server/service.py deleted file mode 100644 index ea8e2bec7..000000000 --- a/pymllm/utils/mllm_convertor_server/service.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) MLLM Team. -# Licensed under the MIT License. 
diff --git a/pyproject.toml b/pyproject.toml index 703d4456a..d752ddc1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] requires = [ - "scikit-build-core>=0.11.0", "apache-tvm-ffi" + "scikit-build-core>=0.11.0", "apache-tvm-ffi == 0.1.8" ] build-backend = "scikit_build_core.build" @@ -21,7 +21,7 @@ dependencies=[ "packaging", "pytest", "pytest-html", - "apache-tvm-ffi == 0.1.0b4", + "apache-tvm-ffi == 0.1.8.post2", "pyyaml >= 6.0.2", "openai", "modelscope", @@ -30,14 +30,18 @@ dependencies=[ "typer", "torch", "torchao", + "pyfiglet", + "termcolor", ] [project.optional-dependencies] -cuda = ["tilelang"] +cuda = ["tilelang", "flashinfer-python", "pyzmq"] [project.scripts] -mllm-convertor = "pymllm.utils.mllm_convertor:main" -mllm-service = "pymllm.service.tools:cli_app" +pymllm = "pymllm.__main__:main" +mllm-convertor = "pymllm.mobile.utils.mllm_convertor:main" +mllm-service = "pymllm.mobile.service.tools:cli_app" +pymllm-server = "pymllm.server.launch:main" [tool.setuptools.exclude-package-data] "*" = ["*.pyc"]