Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
486 changes: 486 additions & 0 deletions .claude/skills/impl-jit-kernel/SKILL.md

Large diffs are not rendered by default.

44 changes: 44 additions & 0 deletions .claude/skills/update-codeowners/SKILL.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
---
name: update-codeowners
description: Updates CODEOWNERS entries safely with consistent path and owner formatting. Use when the user asks to add, remove, or modify CODEOWNERS rules, ownership mappings, reviewers, or module maintainers.
---

# Update CODEOWNERS

## Goal
Maintain `CODEOWNERS` accurately while preserving the repository's existing section/comment style.

## Workflow
1. Read the current `CODEOWNERS` file before editing.
2. Identify requested changes as one of:
- Add new path rule
- Modify owners for existing path rule
- Remove obsolete path rule
- Reorganize section comments (only if requested)
3. Update rules in place instead of creating duplicates for the same path.
4. Keep existing section headers and comment style unless the user asks to refactor structure.
5. Return a concise changelog describing which paths were added, changed, or removed.

## Rule Format
- Use one rule per line: `<path-pattern> <owner1> <owner2> ...`
- Owners must be GitHub handles prefixed with `@`.
- Keep path style consistent with the file (in this repo, path patterns typically start with `/`).
- Do not leave rules with empty owner lists.

## Editing Guidelines
- Prefer minimal edits near related sections.
- If a path already exists, update that line instead of adding a second conflicting line.
- If a new rule logically belongs to an existing section, place it in that section.
- Preserve human-readable grouping and blank lines.
- Keep comments intact unless they are clearly outdated and the user asked for cleanup.

## Validation Checklist
- [ ] Every non-comment, non-empty line has at least one owner.
- [ ] Every owner token starts with `@`.
- [ ] No accidental duplicate rule for the exact same path pattern.
- [ ] Existing comments/sections were preserved unless explicitly changed.

## Example Requests
- "Add `/mllm/models/new_model/ @alice @bob` under models."
- "Change `/core/Storage` owner to `@team-core`."
- "Remove ownership rule for deprecated path `/legacy/`."
2 changes: 1 addition & 1 deletion .codespellrc
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[codespell]
ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS, bfloat, constexpr, cuda, dlpack, expt, forceinline, ifndef, linalg, LPBQ, mllm, pymllm, Quantizaton, Qwen, ROCM, silu, torchao
ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS, bfloat, constexpr, cuda, dlpack, expt, forceinline, ifndef, linalg, LPBQ, mllm, pymllm, Quantizaton, Qwen, ROCM, silu, torchao, flashinfer
skip = *.json,*.jsonl,*.patch,*.txt
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
.cache/
.tmp/
compile_commands.json
.claude/
settings.local.json

# MLLM Team Specific
tasks/mllmteam*
Expand Down
Binary file added assets/pymllm-arch.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions mllm-kernel/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ build-py/
.vscode/settings.json
compile_commands.json
.clangd
.pytest_cache/
33 changes: 16 additions & 17 deletions mllm-kernel/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,31 +80,30 @@ y = add_constant(x, 8)

Use the helpers in `mllm_kernel.jit_utils`:

- `load_cpu_jit`
- `load_cuda_jit`
- `jit`
- `make_cpp_args`
- `cache_once`

Example pattern:
Recommended pattern (CPU example):

```python
import torch
from mllm_kernel.jit_utils import cache_once, load_cpu_jit, make_cpp_args

@cache_once
def _jit_my_kernel_module(param: int):
args = make_cpp_args(param)
return load_cpu_jit(
"my_kernel",
*args,
cpp_files=["my_kernel.cpp"],
cpp_wrappers=[("my_kernel", f"my_namespace::my_kernel<{args}>")],
)
import mllm_kernel

@mllm_kernel.jit(
args=16,
device="cpu",
cpp_files=["my_kernel.cpp"],
cpp_wrappers=[("my_kernel", "my_namespace::my_kernel<16>")],
func_name="my_kernel",
)
def _my_kernel_16(compiled_module, dst: torch.Tensor, src: torch.Tensor) -> None:
compiled_module.my_kernel(dst, src)

def my_kernel(src: torch.Tensor, param: int) -> torch.Tensor:
if param != 16:
raise ValueError("This demo only supports param=16.")
dst = torch.empty_like(src)
module = _jit_my_kernel_module(param)
module.my_kernel(dst, src)
_my_kernel_16(dst, src)
return dst
```

Expand Down
218 changes: 218 additions & 0 deletions mllm-kernel/benchmarks/bench_create_kv_indices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
"""Benchmark create_kv_indices vs naive torch gather using torch.profiler.

Example:
python benchmarks/bench_create_kv_indices.py --batch-size 512 --max-reqs 2048 --max-ctx 4096
"""

from __future__ import annotations

import argparse

import torch
from torch.profiler import ProfilerActivity, profile

from mllm_kernel.cuda.jit.create_kv_indices import create_kv_indices


def _make_batch(
*,
max_reqs: int,
max_ctx: int,
batch_size: int,
use_start_offsets: bool,
device: torch.device,
seed: int,
):
g_cuda = torch.Generator(device=device).manual_seed(seed)
g_cpu = torch.Generator(device="cpu").manual_seed(seed)

req_to_token = torch.arange(
max_reqs * max_ctx, dtype=torch.int32, device=device
).reshape(max_reqs, max_ctx)

assert batch_size <= max_reqs
req_pool_indices = torch.randperm(max_reqs, generator=g_cuda, device=device)[
:batch_size
].to(torch.int32)

page_kernel_lens_list = []
kv_start_idx_list = []
for _ in range(batch_size):
L = int(torch.randint(1, max_ctx, (1,), generator=g_cpu).item())
if use_start_offsets:
start_max = max_ctx - L
start = int(torch.randint(0, max(start_max, 1), (1,), generator=g_cpu).item())
else:
start = 0
page_kernel_lens_list.append(L)
kv_start_idx_list.append(start)

page_kernel_lens = torch.tensor(
page_kernel_lens_list, dtype=torch.int32, device=device
)
kv_start_idx = torch.tensor(kv_start_idx_list, dtype=torch.int32, device=device)

kv_indptr = torch.empty(batch_size + 1, dtype=torch.int32, device=device)
kv_indptr[0] = 0
kv_indptr[1:] = torch.cumsum(page_kernel_lens, dim=0)

kv_indices = torch.empty(
int(kv_indptr[-1].item()), dtype=torch.int32, device=device
)

return (
req_to_token,
req_pool_indices,
page_kernel_lens,
kv_indptr,
kv_start_idx,
kv_indices,
)


def _profile(
    name: str, fn, *, warmup: int, iters: int, row_limit: int, trace_path: str | None
):
    """Profile *fn* with torch.profiler and report self device time.

    Runs *warmup* un-profiled calls, then *iters* profiled calls, prints a
    key_averages table sorted by self device time, and optionally exports a
    Chrome trace to *trace_path*.

    Returns the average self device time per iteration in microseconds.
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=False,
        profile_memory=False,
        with_stack=False,
    ) as prof:
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()

    events = prof.key_averages()
    # Older torch exposes per-event self CUDA time as `self_cuda_time_total`;
    # newer versions renamed it to `self_device_time_total`. The chosen name
    # doubles as the table sort key (the original computed a separate
    # `sort_key` that could only ever equal `time_attr`).
    time_attr = (
        "self_cuda_time_total"
        if events and hasattr(events[0], "self_cuda_time_total")
        else "self_device_time_total"
    )
    total_us = sum(float(getattr(evt, time_attr, 0.0)) for evt in events)
    avg_us = total_us / max(iters, 1)

    print(f"\n=== {name} ===")
    # Reuse the already-computed averages instead of calling key_averages()
    # a second time.
    print(
        events.table(
            sort_by=time_attr,
            row_limit=row_limit,
        )
    )
    print(f"{name} total self device time: {total_us:.2f} us")
    print(f"{name} avg self device time/iter: {avg_us:.2f} us")

    if trace_path:
        prof.export_chrome_trace(trace_path)
        print(f"{name} trace exported: {trace_path}")

    return avg_us


def main() -> None:
    """CLI entry point: profile create_kv_indices against a torch reference."""
    parser = argparse.ArgumentParser(
        description="Benchmark create_kv_indices vs naive torch gather",
    )
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument("--max-reqs", type=int, default=2048)
    parser.add_argument("--max-ctx", type=int, default=4096)
    parser.add_argument("--warmup", type=int, default=50)
    parser.add_argument("--iters", type=int, default=200)
    parser.add_argument("--row-limit", type=int, default=20)
    parser.add_argument("--export-trace-dir", type=str, default="")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--use-start-offsets",
        action="store_true",
        help="Enable non-zero kv_start_idx to emulate sliding-window decode",
    )
    opts = parser.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark")

    torch.manual_seed(opts.seed)
    dev = torch.device("cuda")

    # Build every tensor the kernel and the reference share.
    (
        req_to_token,
        req_pool_indices,
        page_kernel_lens,
        kv_indptr,
        kv_start_idx,
        kv_indices,
    ) = _make_batch(
        max_reqs=opts.max_reqs,
        max_ctx=opts.max_ctx,
        batch_size=opts.batch_size,
        use_start_offsets=opts.use_start_offsets,
        device=dev,
        seed=opts.seed,
    )

    print("=== create_kv_indices profiler benchmark ===")
    print(
        f"batch_size={opts.batch_size}, max_reqs={opts.max_reqs}, max_ctx={opts.max_ctx}, "
        f"use_start_offsets={opts.use_start_offsets}"
    )
    print(f"warmup={opts.warmup}, iters={opts.iters}, row_limit={opts.row_limit}")

    trace_dir = opts.export_trace_dir.strip()
    kernel_trace = f"{trace_dir}/create_kv_indices_trace.json" if trace_dir else None
    torch_trace = f"{trace_dir}/torch_gather_trace.json" if trace_dir else None

    def _kernel_step():
        create_kv_indices(
            req_to_token,
            req_pool_indices,
            page_kernel_lens,
            kv_indptr,
            kv_start_idx,
            kv_indices,
        )

    def _torch_step():
        # Device-side torch reference: slice each request's token range out
        # of req_to_token, then concatenate into the flat index buffer.
        rows = []
        for b in range(opts.batch_size):
            slot = req_pool_indices[b].item()
            begin = kv_start_idx[b].item() if opts.use_start_offsets else 0
            count = page_kernel_lens[b].item()
            rows.append(req_to_token[slot, begin : begin + count])
        torch.cat(rows, out=kv_indices)

    kernel_avg_us = _profile(
        "create_kv_indices",
        _kernel_step,
        warmup=opts.warmup,
        iters=opts.iters,
        row_limit=opts.row_limit,
        trace_path=kernel_trace,
    )

    torch_avg_us = _profile(
        "torch_reference",
        _torch_step,
        warmup=opts.warmup,
        iters=opts.iters,
        row_limit=opts.row_limit,
        trace_path=torch_trace,
    )

    # Guard against division by zero on a pathological zero-time reading.
    speedup = torch_avg_us / max(kernel_avg_us, 1e-12)
    print(f"\nSpeedup: {speedup:.3f}x")


# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Loading