diff --git a/examples/README.md b/examples/README.md
index 470030e..215a405 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -1,43 +1,87 @@
# PyGPUkit Examples
+## Directory Structure
+
+```
+examples/
+├── benchmarks/ # Performance benchmarks
+├── chat/ # Chat CLI applications
+├── demos/archived/ # Version-specific demos (historical)
+├── demo_*.py # Current feature demos
+├── tts.py # Text-to-speech example
+└── whisper_realtime_stt.py # Speech-to-text example
+```
+
## Requirements
-- NVIDIA GPU with CUDA support
-- CUDA Toolkit 12.x
+- NVIDIA GPU with SM >= 80 (Ampere or newer)
+- CUDA Toolkit 12.x or 13.x
- Built native module (`_pygpukit_native`)
-## Examples
+## Quick Start
-### demo_gpu.py
-Basic GPU operations demo using the native C++ backend directly.
+### Chat CLI
```bash
+# Standard chat (Qwen)
+python examples/chat/chat_cli.py
+
+# With Triton backend
+python examples/chat/chat_cli_triton.py
+
+# MoE models (Qwen3)
+python examples/chat/chat_cli_moe.py
+
+# Thinking mode (Qwen3-8B-Thinking)
+python examples/chat/chat_cli_thinking.py
+```
+
+### Demos
+
+```bash
+# Basic GPU operations
python examples/demo_gpu.py
+
+# CUDA Graph for LLM inference
+python examples/demo_cuda_graph.py
+
+# End-to-end LLM demo
+python examples/demo_llm_e2e.py
+
+# Qwen3 model demo
+python examples/demo_qwen3.py
```
-### demo_optimized.py
-Performance comparison showing zero-copy optimizations.
+### Benchmarks
```bash
-python examples/demo_optimized.py
+# Matrix multiplication benchmark
+python examples/benchmarks/benchmark_matmul.py
+
+# CUDA Graph LLM benchmark
+python examples/benchmarks/bench_cuda_graph_llm.py
+
+# Compare with cuBLAS
+python examples/benchmarks/benchmark_compare.py
```
-### demo_v01.py
-Simple v0.1 feature demonstration (CPU simulation fallback).
+### Speech/Audio
```bash
-python examples/demo_v01.py
+# Text-to-speech (Kokoro)
+python examples/tts.py
+
+# Real-time speech-to-text (Whisper)
+python examples/whisper_realtime_stt.py
```
## Building Native Module
```bash
-cd native
-mkdir build && cd build
-cmake .. -DCMAKE_BUILD_TYPE=Release
-cmake --build . --config Release
-```
+# From project root using build script
+./build.sh 86 # RTX 3090 Ti
+./build.sh 120a # RTX 5090
-Copy the built module to `src/pygpukit/`:
-- Linux: `_pygpukit_native.cpython-3xx-x86_64-linux-gnu.so`
-- Windows: `_pygpukit_native.cp3xx-win_amd64.pyd`
+# Or manually with pip
+pip install -e . -v
+```
diff --git a/examples/bench_cuda_graph_llm.py b/examples/benchmarks/bench_cuda_graph_llm.py
similarity index 100%
rename from examples/bench_cuda_graph_llm.py
rename to examples/benchmarks/bench_cuda_graph_llm.py
diff --git a/examples/benchmark_compare.py b/examples/benchmarks/benchmark_compare.py
similarity index 100%
rename from examples/benchmark_compare.py
rename to examples/benchmarks/benchmark_compare.py
diff --git a/examples/benchmark_large.py b/examples/benchmarks/benchmark_large.py
similarity index 100%
rename from examples/benchmark_large.py
rename to examples/benchmarks/benchmark_large.py
diff --git a/examples/benchmark_matmul.py b/examples/benchmarks/benchmark_matmul.py
similarity index 100%
rename from examples/benchmark_matmul.py
rename to examples/benchmarks/benchmark_matmul.py
diff --git a/examples/benchmark_tiled_matmul.py b/examples/benchmarks/benchmark_tiled_matmul.py
similarity index 100%
rename from examples/benchmark_tiled_matmul.py
rename to examples/benchmarks/benchmark_tiled_matmul.py
diff --git a/examples/chat_cli.py b/examples/chat/chat_cli.py
similarity index 100%
rename from examples/chat_cli.py
rename to examples/chat/chat_cli.py
diff --git a/examples/chat_cli_moe.py b/examples/chat/chat_cli_moe.py
similarity index 97%
rename from examples/chat_cli_moe.py
rename to examples/chat/chat_cli_moe.py
index 8845b3b..5ad60c7 100644
--- a/examples/chat_cli_moe.py
+++ b/examples/chat/chat_cli_moe.py
@@ -1,572 +1,572 @@
-#!/usr/bin/env python3
-"""
-PyGPUkit - MoE (Mixture of Experts) Chat CLI
-
-A minimal chat interface for MoE models (Mixtral, Qwen3-MoE, etc.).
-Supports multiple chat templates with auto-detection.
-
-Usage:
- python examples/chat_cli_moe.py --model /path/to/model.safetensors.index.json --tokenizer /path/to/tokenizer.json
-
-Example (Qwen3-30B-A3B MoE):
- python examples/chat_cli_moe.py \
- --model /path/to/Qwen3-30B-A3B/model.safetensors.index.json \
- --tokenizer /path/to/Qwen3-30B-A3B/tokenizer.json
-
-Example (Mixtral-8x7B):
- python examples/chat_cli_moe.py \
- --model /path/to/Mixtral-8x7B/model.safetensors.index.json \
- --tokenizer /path/to/Mixtral-8x7B/tokenizer.json
-
-Example with explicit chat template:
- python examples/chat_cli_moe.py \
- --model /path/to/model --chat-template qwen
-
-Example with CUDA Graph (faster decode):
- python examples/chat_cli_moe.py \
- --model /path/to/model --cuda-graph
-
-Supported chat templates:
- qwen - Qwen2/Qwen3 (<|im_start|>...<|im_end|>)
- mistral - Mistral/Mixtral ([INST]...[/INST])
- llama2 - LLaMA 2 (<>...<>)
- llama3 - LLaMA 3 (<|start_header_id|>...<|eot_id|>)
- chatml - Generic ChatML
-
-Commands:
- /clear - Clear conversation history
- /quit - Exit chat
-"""
-
-from __future__ import annotations
-
-import argparse
-import os
-import sys
-import time
-
-# Fix Windows console encoding for Unicode output
-if sys.platform == "win32":
- sys.stdout.reconfigure(encoding="utf-8")
- sys.stderr.reconfigure(encoding="utf-8")
-
-# Suppress cuBLASLt debug output
-os.environ.setdefault("PYGPUKIT_CUBLASLT_DEBUG", "0")
-
-import numpy as np
-
-
-def logits_to_f32(logits_gpu) -> np.ndarray:
- """Convert logits GPU array to numpy float32."""
- logits_np = logits_gpu.to_numpy()
- if logits_np.dtype == np.uint16:
- # bf16 stored as uint16 - convert to fp32
- return (logits_np.astype(np.uint32) << 16).view(np.float32)
- return logits_np.astype(np.float32)
-
-
-def _build_byte_decoder() -> dict[str, int]:
- """Build the unicode-to-byte mapping used by GPT-2/Mistral style tokenizers."""
- bs = (
- list(range(ord("!"), ord("~") + 1))
- + list(range(ord("\xa1"), ord("\xac") + 1))
- + list(range(ord("\xae"), ord("\xff") + 1))
- )
- cs = bs[:]
- n = 0
- for b in range(256):
- if b not in bs:
- bs.append(b)
- cs.append(256 + n)
- n += 1
- return {chr(c): b for b, c in zip(bs, cs)}
-
-
-_BYTE_DECODER = _build_byte_decoder()
-
-
-def _token_str_to_bytes(token_str: str) -> bytes:
- """Convert a token string to raw bytes."""
- result = []
- for char in token_str:
- if char in _BYTE_DECODER:
- result.append(_BYTE_DECODER[char])
- else:
- result.extend(char.encode("utf-8"))
- return bytes(result)
-
-
-class StreamingDecoder:
- """Streaming decoder for UTF-8 safe output."""
-
- def __init__(self, tokenizer):
- self.tokenizer = tokenizer
- self.pending_bytes = b""
- self._cache: dict[int, bytes] = {}
-
- def _get_token_bytes(self, token_id: int) -> bytes:
- cached = self._cache.get(token_id)
- if cached is not None:
- return cached
- token_str = self.tokenizer.id_to_token(token_id)
- if token_str is None:
- result = b""
- else:
- result = _token_str_to_bytes(token_str)
- self._cache[token_id] = result
- return result
-
- def add_token(self, token_id: int) -> str:
- new_bytes = self._get_token_bytes(token_id)
- if not new_bytes:
- return ""
-
- all_bytes = self.pending_bytes + new_bytes
- valid_end = 0
- i = 0
- while i < len(all_bytes):
- byte = all_bytes[i]
- if byte < 0x80:
- valid_end = i + 1
- i += 1
- elif byte < 0xC0:
- i += 1
- elif byte < 0xE0:
- if i + 1 < len(all_bytes) and 0x80 <= all_bytes[i + 1] < 0xC0:
- valid_end = i + 2
- i += 2
- else:
- break
- elif byte < 0xF0:
- if (
- i + 2 < len(all_bytes)
- and 0x80 <= all_bytes[i + 1] < 0xC0
- and 0x80 <= all_bytes[i + 2] < 0xC0
- ):
- valid_end = i + 3
- i += 3
- else:
- break
- elif byte < 0xF8:
- if (
- i + 3 < len(all_bytes)
- and 0x80 <= all_bytes[i + 1] < 0xC0
- and 0x80 <= all_bytes[i + 2] < 0xC0
- and 0x80 <= all_bytes[i + 3] < 0xC0
- ):
- valid_end = i + 4
- i += 4
- else:
- break
- else:
- i += 1
-
- complete_bytes = all_bytes[:valid_end]
- self.pending_bytes = all_bytes[valid_end:]
-
- if complete_bytes:
- return complete_bytes.decode("utf-8", errors="replace")
- return ""
-
- def flush(self) -> str:
- if self.pending_bytes:
- text = self.pending_bytes.decode("utf-8", errors="replace")
- self.pending_bytes = b""
- return text
- return ""
-
- def reset(self):
- self.pending_bytes = b""
-
-
-def detect_chat_template(spec_name: str) -> str:
- """Detect chat template from model spec name."""
- name = spec_name.lower()
- if "qwen" in name:
- return "qwen"
- elif "mixtral" in name or "mistral" in name:
- return "mistral"
- elif "llama3" in name or "llama-3" in name:
- return "llama3"
- elif "llama" in name:
- return "llama2"
- return "chatml"
-
-
-def main():
- parser = argparse.ArgumentParser(
- description="PyGPUkit MoE Chat CLI",
- formatter_class=argparse.RawDescriptionHelpFormatter,
- )
- parser.add_argument(
- "--model",
- type=str,
- required=True,
- help="Path to model.safetensors or model.safetensors.index.json",
- )
- parser.add_argument(
- "--tokenizer",
- type=str,
- required=True,
- help="Path to tokenizer.json",
- )
- parser.add_argument(
- "--max-seq-len",
- type=int,
- default=4096,
- help="Maximum sequence length (default: 4096)",
- )
- parser.add_argument(
- "--max-new-tokens",
- type=int,
- default=512,
- help="Maximum new tokens per response (default: 512)",
- )
- parser.add_argument(
- "--temperature",
- type=float,
- default=0.7,
- help="Sampling temperature (default: 0.7)",
- )
- parser.add_argument(
- "--top-k",
- type=int,
- default=50,
- help="Top-k sampling (default: 50)",
- )
- parser.add_argument(
- "--top-p",
- type=float,
- default=0.9,
- help="Top-p (nucleus) sampling (default: 0.9)",
- )
- parser.add_argument(
- "--system",
- type=str,
- default="You are a helpful assistant.",
- help="System prompt",
- )
- parser.add_argument(
- "--repetition-penalty",
- type=float,
- default=1.1,
- help="Repetition penalty (default: 1.1, 1.0 = disabled)",
- )
- parser.add_argument(
- "--dtype",
- type=str,
- default="bfloat16",
- choices=["float16", "bfloat16", "float32"],
- help="Model dtype (default: bfloat16)",
- )
- parser.add_argument(
- "--cuda-graph",
- action="store_true",
- help="Enable CUDA Graph for faster decode (reduces kernel launch overhead)",
- )
- parser.add_argument(
- "--chat-template",
- type=str,
- default=None,
- choices=["qwen", "mistral", "llama2", "llama3", "chatml"],
- help="Chat template (auto-detected from model if not specified)",
- )
- args = parser.parse_args()
-
- # Lazy imports for faster --help
- print("Loading PyGPUkit...")
- from tokenizers import Tokenizer
-
- from pygpukit.core import default_stream, from_numpy
- from pygpukit.llm import (
- MIXTRAL_SPEC,
- DecodeM1Graph,
- detect_model_spec,
- load_model_from_safetensors,
- load_safetensors,
- )
- from pygpukit.llm.buffers import DecodeBuffers
- from pygpukit.llm.chat import format_chat_messages
- from pygpukit.llm.layers import precompute_freqs_cis
- from pygpukit.llm.sampling import sample_token
- from pygpukit.ops.basic import kv_cache_prefill_gqa
-
- # =========================================================================
- # Load Model
- # =========================================================================
- print(f"\nLoading MoE model from: {args.model}")
- print(f" dtype: {args.dtype}")
- t0 = time.perf_counter()
-
- tokenizer = Tokenizer.from_file(args.tokenizer)
- st = load_safetensors(args.model)
- spec = detect_model_spec(st.tensor_names)
-
- # Verify it's a MoE model
- if spec is None:
- print("Warning: Could not auto-detect model spec, using MIXTRAL_SPEC")
- spec = MIXTRAL_SPEC
- elif not spec.is_moe:
- print(f"Warning: Detected {spec.name} which is not a MoE model")
- print("This example is optimized for MoE models like Mixtral")
-
- model = load_model_from_safetensors(args.model, dtype=args.dtype, spec=spec)
-
- load_time = time.perf_counter() - t0
- print(f"Model loaded in {load_time:.1f}s")
-
- # Model info
- config = model.config
- print(f" Architecture: {spec.name if spec else 'unknown'}")
- print(f" Layers: {config.num_layers}, Hidden: {config.hidden_size}")
- print(f" Vocab size: {model.embed_tokens.shape[0]}")
- if config.num_experts:
- print(f" MoE: {config.num_experts} experts, top-{config.num_experts_per_tok}")
-
- # Determine chat template
- chat_template = args.chat_template
- if chat_template is None:
- chat_template = detect_chat_template(spec.name if spec else "")
- print(f" Chat template: {chat_template}")
-
- # =========================================================================
- # Initialize KV Cache
- # =========================================================================
- print(f"\nInitializing KV cache (max_seq_len={args.max_seq_len})...")
-
- for block in model.blocks:
- block.attn.init_fixed_cache(args.max_seq_len, dtype=args.dtype)
-
- # =========================================================================
- # Initialize Decode Buffers
- # =========================================================================
- use_qk_norm = model.spec is not None and model.spec.use_qk_norm
- lm_head = model._lm_head if model._lm_head is not None else model.embed_tokens
- vocab_size = lm_head.shape[0]
-
- decode_buffers = DecodeBuffers.allocate(
- config, dtype=args.dtype, use_qk_norm=use_qk_norm, vocab_size=vocab_size
- )
-
- # Precompute RoPE frequencies
- if config.use_rope:
- cos_np, sin_np = precompute_freqs_cis(config.head_dim, args.max_seq_len, config.rope_theta)
- if args.dtype == "float16":
- model._rope_cos_gpu = from_numpy(cos_np.astype(np.float16))
- model._rope_sin_gpu = from_numpy(sin_np.astype(np.float16))
- elif args.dtype == "bfloat16":
- cos_u32 = cos_np.view(np.uint32)
- sin_u32 = sin_np.view(np.uint32)
- cos_bf16 = ((cos_u32 + 0x7FFF + ((cos_u32 >> 16) & 1)) >> 16).astype(np.uint16)
- sin_bf16 = ((sin_u32 + 0x7FFF + ((sin_u32 >> 16) & 1)) >> 16).astype(np.uint16)
- model._rope_cos_gpu = from_numpy(cos_bf16)
- model._rope_sin_gpu = from_numpy(sin_bf16)
- else:
- model._rope_cos_gpu = from_numpy(cos_np.astype(np.float32))
- model._rope_sin_gpu = from_numpy(sin_np.astype(np.float32))
-
- default_stream().synchronize()
-
- # =========================================================================
- # Initialize CUDA Graph (optional)
- # =========================================================================
- use_cuda_graph = args.cuda_graph
- m1_graph = None
-
- if use_cuda_graph:
- print("\nInitializing CUDA Graph...")
- m1_graph = DecodeM1Graph()
- m1_graph.bind(model)
- m1_graph.init_graph(max_seq_len=args.max_seq_len)
- print(f" CUDA Graph ready (max_seq_len={args.max_seq_len})")
-
- print("Ready!")
-
- # =========================================================================
- # Chat State
- # =========================================================================
- conversation: list[dict] = []
- system_msg = {"role": "system", "content": args.system}
-
- # Get EOS tokens (model-specific)
- eos_token_ids: set[int] = set()
- for eos_str in ["", "<|endoftext|>", "<|im_end|>", "<|eot_id|>"]:
- tid = tokenizer.token_to_id(eos_str)
- if tid is not None:
- eos_token_ids.add(tid)
-
- def is_end_token(token_id: int) -> bool:
- return token_id in eos_token_ids
-
- def apply_repetition_penalty(
- logits: np.ndarray, generated_ids: list[int], penalty: float
- ) -> np.ndarray:
- if penalty == 1.0 or not generated_ids:
- return logits
- logits = logits.copy()
- for token_id in set(generated_ids):
- if logits[token_id] > 0:
- logits[token_id] /= penalty
- else:
- logits[token_id] *= penalty
- return logits
-
- # =========================================================================
- # Decode Helper (CUDA Graph or Non-Graph)
- # =========================================================================
- def decode_one_token(token_id: int, position: int, context_len: int) -> np.ndarray:
- """Decode one token and return logits as numpy array.
-
- Uses CUDA Graph if enabled, otherwise falls back to standard decode.
- """
- if use_cuda_graph and m1_graph is not None:
- logits = m1_graph.step_graph(token_id, position, context_len)
- return logits_to_f32(logits)[-1]
- else:
- hidden = model._decode_step_fixed_cache(token_id, position, context_len)
- logits = model.get_logits(hidden)
- return logits_to_f32(logits)[-1]
-
- # =========================================================================
- # Generation Function
- # =========================================================================
- def generate(messages: list[dict]) -> tuple[str, float, float, int]:
- """Generate response using M=1 decode."""
- prompt = format_chat_messages(messages, model_type=chat_template)
- input_ids = tokenizer.encode(prompt).ids
-
- if len(input_ids) >= args.max_seq_len - 10:
- return "[Error: Conversation too long. Use /clear to reset.]", 0, 0, 0
-
- # Prefill
- t_prefill_start = time.perf_counter()
- hidden, past_key_values = model(input_ids, use_cache=True)
-
- for i, block in enumerate(model.blocks):
- past_k, past_v = past_key_values[i]
- kv_cache_prefill_gqa(past_k, block.attn._k_cache, block.attn.num_heads, start_pos=0)
- kv_cache_prefill_gqa(past_v, block.attn._v_cache, block.attn.num_heads, start_pos=0)
- default_stream().synchronize()
- prefill_time = time.perf_counter() - t_prefill_start
-
- # Decode
- t_decode_start = time.perf_counter()
- logits = model.get_logits(hidden)
- last_logits = logits_to_f32(logits)[-1]
- next_token = sample_token(last_logits, args.temperature, args.top_k, args.top_p)
-
- generated_ids: list[int] = []
- position = len(input_ids)
- context_len = position + 1
-
- # Check if first token is end token
- if is_end_token(next_token):
- default_stream().synchronize()
- decode_time = time.perf_counter() - t_decode_start
- return "", prefill_time, decode_time, 0
-
- # Use streaming decoder for UTF-8 safe output
- stream_decoder = StreamingDecoder(tokenizer)
-
- # Output first token
- text_chunk = stream_decoder.add_token(next_token)
- if text_chunk:
- print(text_chunk, end="", flush=True)
- generated_ids.append(next_token)
-
- while len(generated_ids) < args.max_new_tokens:
- if context_len >= args.max_seq_len:
- break
-
- # Decode one token (CUDA Graph or standard)
- logits_np = decode_one_token(next_token, position, context_len)
- logits_np = apply_repetition_penalty(logits_np, generated_ids, args.repetition_penalty)
- next_token = sample_token(logits_np, args.temperature, args.top_k, args.top_p)
-
- if is_end_token(next_token):
- break
-
- generated_ids.append(next_token)
- position += 1
- context_len += 1
-
- text_chunk = stream_decoder.add_token(next_token)
- if text_chunk:
- print(text_chunk, end="", flush=True)
-
- # Flush any remaining buffered text
- remaining = stream_decoder.flush()
- if remaining:
- print(remaining, end="", flush=True)
-
- default_stream().synchronize()
- decode_time = time.perf_counter() - t_decode_start
-
- print()
- return tokenizer.decode(generated_ids), prefill_time, decode_time, len(generated_ids)
-
- # =========================================================================
- # Chat Loop
- # =========================================================================
- print("\n" + "=" * 60)
- print(" PyGPUkit MoE Chat")
- if config.num_experts:
- print(
- f" Model: {spec.name} ({config.num_experts} experts, top-{config.num_experts_per_tok})"
- )
- else:
- print(f" Model: {spec.name}")
- print(f" CUDA Graph: {'ON' if use_cuda_graph else 'OFF'}")
- print(" Commands: /clear (reset), /quit (exit)")
- print("=" * 60)
-
- while True:
- try:
- user_input = input("\nYou: ").strip()
- except (EOFError, KeyboardInterrupt):
- print("\nGoodbye!")
- break
-
- if not user_input:
- continue
-
- # Commands
- if user_input.lower() == "/quit":
- print("Goodbye!")
- break
- elif user_input.lower() == "/clear":
- conversation.clear()
- print("[Conversation cleared]")
- continue
-
- # Add user message
- conversation.append({"role": "user", "content": user_input})
-
- # Build full message list (without system prompt for now)
- messages = conversation
-
- # Generate response
- print("\nAssistant: ", end="", flush=True)
-
- response, prefill_time, decode_time, tokens_generated = generate(messages)
-
- # Add assistant response to history
- conversation.append({"role": "assistant", "content": response})
-
- # Stats
- decode_tps = tokens_generated / decode_time if decode_time > 0 else 0
- print(
- f" [prefill: {prefill_time:.1f}s, "
- f"decode: {tokens_generated} tok / {decode_time:.1f}s = {decode_tps:.1f} tok/s]"
- )
-
- # =========================================================================
- # Cleanup
- # =========================================================================
- print("\nUnloading model...")
- del model
- print("Done.")
-
-
-if __name__ == "__main__":
- main()
+#!/usr/bin/env python3
+"""
+PyGPUkit - MoE (Mixture of Experts) Chat CLI
+
+A minimal chat interface for MoE models (Mixtral, Qwen3-MoE, etc.).
+Supports multiple chat templates with auto-detection.
+
+Usage:
+ python examples/chat_cli_moe.py --model /path/to/model.safetensors.index.json --tokenizer /path/to/tokenizer.json
+
+Example (Qwen3-30B-A3B MoE):
+ python examples/chat_cli_moe.py \
+ --model /path/to/Qwen3-30B-A3B/model.safetensors.index.json \
+ --tokenizer /path/to/Qwen3-30B-A3B/tokenizer.json
+
+Example (Mixtral-8x7B):
+ python examples/chat_cli_moe.py \
+ --model /path/to/Mixtral-8x7B/model.safetensors.index.json \
+ --tokenizer /path/to/Mixtral-8x7B/tokenizer.json
+
+Example with explicit chat template:
+ python examples/chat_cli_moe.py \
+ --model /path/to/model --chat-template qwen
+
+Example with CUDA Graph (faster decode):
+ python examples/chat_cli_moe.py \
+ --model /path/to/model --cuda-graph
+
+Supported chat templates:
+ qwen - Qwen2/Qwen3 (<|im_start|>...<|im_end|>)
+ mistral - Mistral/Mixtral ([INST]...[/INST])
+ llama2 - LLaMA 2 (<>...<>)
+ llama3 - LLaMA 3 (<|start_header_id|>...<|eot_id|>)
+ chatml - Generic ChatML
+
+Commands:
+ /clear - Clear conversation history
+ /quit - Exit chat
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import sys
+import time
+
+# Fix Windows console encoding for Unicode output
+if sys.platform == "win32":
+ sys.stdout.reconfigure(encoding="utf-8")
+ sys.stderr.reconfigure(encoding="utf-8")
+
+# Suppress cuBLASLt debug output
+os.environ.setdefault("PYGPUKIT_CUBLASLT_DEBUG", "0")
+
+import numpy as np
+
+
+def logits_to_f32(logits_gpu) -> np.ndarray:
+ """Convert logits GPU array to numpy float32."""
+ logits_np = logits_gpu.to_numpy()
+ if logits_np.dtype == np.uint16:
+ # bf16 stored as uint16 - convert to fp32
+ return (logits_np.astype(np.uint32) << 16).view(np.float32)
+ return logits_np.astype(np.float32)
+
+
+def _build_byte_decoder() -> dict[str, int]:
+ """Build the unicode-to-byte mapping used by GPT-2/Mistral style tokenizers."""
+ bs = (
+ list(range(ord("!"), ord("~") + 1))
+ + list(range(ord("\xa1"), ord("\xac") + 1))
+ + list(range(ord("\xae"), ord("\xff") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(256):
+ if b not in bs:
+ bs.append(b)
+ cs.append(256 + n)
+ n += 1
+ return {chr(c): b for b, c in zip(bs, cs)}
+
+
+_BYTE_DECODER = _build_byte_decoder()
+
+
+def _token_str_to_bytes(token_str: str) -> bytes:
+ """Convert a token string to raw bytes."""
+ result = []
+ for char in token_str:
+ if char in _BYTE_DECODER:
+ result.append(_BYTE_DECODER[char])
+ else:
+ result.extend(char.encode("utf-8"))
+ return bytes(result)
+
+
+class StreamingDecoder:
+ """Streaming decoder for UTF-8 safe output."""
+
+ def __init__(self, tokenizer):
+ self.tokenizer = tokenizer
+ self.pending_bytes = b""
+ self._cache: dict[int, bytes] = {}
+
+ def _get_token_bytes(self, token_id: int) -> bytes:
+ cached = self._cache.get(token_id)
+ if cached is not None:
+ return cached
+ token_str = self.tokenizer.id_to_token(token_id)
+ if token_str is None:
+ result = b""
+ else:
+ result = _token_str_to_bytes(token_str)
+ self._cache[token_id] = result
+ return result
+
+ def add_token(self, token_id: int) -> str:
+ new_bytes = self._get_token_bytes(token_id)
+ if not new_bytes:
+ return ""
+
+ all_bytes = self.pending_bytes + new_bytes
+ valid_end = 0
+ i = 0
+ while i < len(all_bytes):
+ byte = all_bytes[i]
+ if byte < 0x80:
+ valid_end = i + 1
+ i += 1
+ elif byte < 0xC0:
+ i += 1
+ elif byte < 0xE0:
+ if i + 1 < len(all_bytes) and 0x80 <= all_bytes[i + 1] < 0xC0:
+ valid_end = i + 2
+ i += 2
+ else:
+ break
+ elif byte < 0xF0:
+ if (
+ i + 2 < len(all_bytes)
+ and 0x80 <= all_bytes[i + 1] < 0xC0
+ and 0x80 <= all_bytes[i + 2] < 0xC0
+ ):
+ valid_end = i + 3
+ i += 3
+ else:
+ break
+ elif byte < 0xF8:
+ if (
+ i + 3 < len(all_bytes)
+ and 0x80 <= all_bytes[i + 1] < 0xC0
+ and 0x80 <= all_bytes[i + 2] < 0xC0
+ and 0x80 <= all_bytes[i + 3] < 0xC0
+ ):
+ valid_end = i + 4
+ i += 4
+ else:
+ break
+ else:
+ i += 1
+
+ complete_bytes = all_bytes[:valid_end]
+ self.pending_bytes = all_bytes[valid_end:]
+
+ if complete_bytes:
+ return complete_bytes.decode("utf-8", errors="replace")
+ return ""
+
+ def flush(self) -> str:
+ if self.pending_bytes:
+ text = self.pending_bytes.decode("utf-8", errors="replace")
+ self.pending_bytes = b""
+ return text
+ return ""
+
+ def reset(self):
+ self.pending_bytes = b""
+
+
+def detect_chat_template(spec_name: str) -> str:
+ """Detect chat template from model spec name."""
+ name = spec_name.lower()
+ if "qwen" in name:
+ return "qwen"
+ elif "mixtral" in name or "mistral" in name:
+ return "mistral"
+ elif "llama3" in name or "llama-3" in name:
+ return "llama3"
+ elif "llama" in name:
+ return "llama2"
+ return "chatml"
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="PyGPUkit MoE Chat CLI",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ required=True,
+ help="Path to model.safetensors or model.safetensors.index.json",
+ )
+ parser.add_argument(
+ "--tokenizer",
+ type=str,
+ required=True,
+ help="Path to tokenizer.json",
+ )
+ parser.add_argument(
+ "--max-seq-len",
+ type=int,
+ default=4096,
+ help="Maximum sequence length (default: 4096)",
+ )
+ parser.add_argument(
+ "--max-new-tokens",
+ type=int,
+ default=512,
+ help="Maximum new tokens per response (default: 512)",
+ )
+ parser.add_argument(
+ "--temperature",
+ type=float,
+ default=0.7,
+ help="Sampling temperature (default: 0.7)",
+ )
+ parser.add_argument(
+ "--top-k",
+ type=int,
+ default=50,
+ help="Top-k sampling (default: 50)",
+ )
+ parser.add_argument(
+ "--top-p",
+ type=float,
+ default=0.9,
+ help="Top-p (nucleus) sampling (default: 0.9)",
+ )
+ parser.add_argument(
+ "--system",
+ type=str,
+ default="You are a helpful assistant.",
+ help="System prompt",
+ )
+ parser.add_argument(
+ "--repetition-penalty",
+ type=float,
+ default=1.1,
+ help="Repetition penalty (default: 1.1, 1.0 = disabled)",
+ )
+ parser.add_argument(
+ "--dtype",
+ type=str,
+ default="bfloat16",
+ choices=["float16", "bfloat16", "float32"],
+ help="Model dtype (default: bfloat16)",
+ )
+ parser.add_argument(
+ "--cuda-graph",
+ action="store_true",
+ help="Enable CUDA Graph for faster decode (reduces kernel launch overhead)",
+ )
+ parser.add_argument(
+ "--chat-template",
+ type=str,
+ default=None,
+ choices=["qwen", "mistral", "llama2", "llama3", "chatml"],
+ help="Chat template (auto-detected from model if not specified)",
+ )
+ args = parser.parse_args()
+
+ # Lazy imports for faster --help
+ print("Loading PyGPUkit...")
+ from tokenizers import Tokenizer
+
+ from pygpukit.core import default_stream, from_numpy
+ from pygpukit.llm import (
+ MIXTRAL_SPEC,
+ DecodeM1Graph,
+ detect_model_spec,
+ load_model_from_safetensors,
+ load_safetensors,
+ )
+ from pygpukit.llm.buffers import DecodeBuffers
+ from pygpukit.llm.chat import format_chat_messages
+ from pygpukit.llm.layers import precompute_freqs_cis
+ from pygpukit.llm.sampling import sample_token
+ from pygpukit.ops.basic import kv_cache_prefill_gqa
+
+ # =========================================================================
+ # Load Model
+ # =========================================================================
+ print(f"\nLoading MoE model from: {args.model}")
+ print(f" dtype: {args.dtype}")
+ t0 = time.perf_counter()
+
+ tokenizer = Tokenizer.from_file(args.tokenizer)
+ st = load_safetensors(args.model)
+ spec = detect_model_spec(st.tensor_names)
+
+ # Verify it's a MoE model
+ if spec is None:
+ print("Warning: Could not auto-detect model spec, using MIXTRAL_SPEC")
+ spec = MIXTRAL_SPEC
+ elif not spec.is_moe:
+ print(f"Warning: Detected {spec.name} which is not a MoE model")
+ print("This example is optimized for MoE models like Mixtral")
+
+ model = load_model_from_safetensors(args.model, dtype=args.dtype, spec=spec)
+
+ load_time = time.perf_counter() - t0
+ print(f"Model loaded in {load_time:.1f}s")
+
+ # Model info
+ config = model.config
+ print(f" Architecture: {spec.name if spec else 'unknown'}")
+ print(f" Layers: {config.num_layers}, Hidden: {config.hidden_size}")
+ print(f" Vocab size: {model.embed_tokens.shape[0]}")
+ if config.num_experts:
+ print(f" MoE: {config.num_experts} experts, top-{config.num_experts_per_tok}")
+
+ # Determine chat template
+ chat_template = args.chat_template
+ if chat_template is None:
+ chat_template = detect_chat_template(spec.name if spec else "")
+ print(f" Chat template: {chat_template}")
+
+ # =========================================================================
+ # Initialize KV Cache
+ # =========================================================================
+ print(f"\nInitializing KV cache (max_seq_len={args.max_seq_len})...")
+
+ for block in model.blocks:
+ block.attn.init_fixed_cache(args.max_seq_len, dtype=args.dtype)
+
+ # =========================================================================
+ # Initialize Decode Buffers
+ # =========================================================================
+ use_qk_norm = model.spec is not None and model.spec.use_qk_norm
+ lm_head = model._lm_head if model._lm_head is not None else model.embed_tokens
+ vocab_size = lm_head.shape[0]
+
+ decode_buffers = DecodeBuffers.allocate(
+ config, dtype=args.dtype, use_qk_norm=use_qk_norm, vocab_size=vocab_size
+ )
+
+ # Precompute RoPE frequencies
+ if config.use_rope:
+ cos_np, sin_np = precompute_freqs_cis(config.head_dim, args.max_seq_len, config.rope_theta)
+ if args.dtype == "float16":
+ model._rope_cos_gpu = from_numpy(cos_np.astype(np.float16))
+ model._rope_sin_gpu = from_numpy(sin_np.astype(np.float16))
+ elif args.dtype == "bfloat16":
+ cos_u32 = cos_np.view(np.uint32)
+ sin_u32 = sin_np.view(np.uint32)
+ cos_bf16 = ((cos_u32 + 0x7FFF + ((cos_u32 >> 16) & 1)) >> 16).astype(np.uint16)
+ sin_bf16 = ((sin_u32 + 0x7FFF + ((sin_u32 >> 16) & 1)) >> 16).astype(np.uint16)
+ model._rope_cos_gpu = from_numpy(cos_bf16)
+ model._rope_sin_gpu = from_numpy(sin_bf16)
+ else:
+ model._rope_cos_gpu = from_numpy(cos_np.astype(np.float32))
+ model._rope_sin_gpu = from_numpy(sin_np.astype(np.float32))
+
+ default_stream().synchronize()
+
+ # =========================================================================
+ # Initialize CUDA Graph (optional)
+ # =========================================================================
+ use_cuda_graph = args.cuda_graph
+ m1_graph = None
+
+ if use_cuda_graph:
+ print("\nInitializing CUDA Graph...")
+ m1_graph = DecodeM1Graph()
+ m1_graph.bind(model)
+ m1_graph.init_graph(max_seq_len=args.max_seq_len)
+ print(f" CUDA Graph ready (max_seq_len={args.max_seq_len})")
+
+ print("Ready!")
+
+ # =========================================================================
+ # Chat State
+ # =========================================================================
+ conversation: list[dict] = []
+ system_msg = {"role": "system", "content": args.system}
+
+ # Get EOS tokens (model-specific)
+ eos_token_ids: set[int] = set()
+ for eos_str in ["", "<|endoftext|>", "<|im_end|>", "<|eot_id|>"]:
+ tid = tokenizer.token_to_id(eos_str)
+ if tid is not None:
+ eos_token_ids.add(tid)
+
+ def is_end_token(token_id: int) -> bool:
+ return token_id in eos_token_ids
+
+ def apply_repetition_penalty(
+ logits: np.ndarray, generated_ids: list[int], penalty: float
+ ) -> np.ndarray:
+ if penalty == 1.0 or not generated_ids:
+ return logits
+ logits = logits.copy()
+ for token_id in set(generated_ids):
+ if logits[token_id] > 0:
+ logits[token_id] /= penalty
+ else:
+ logits[token_id] *= penalty
+ return logits
+
+ # =========================================================================
+ # Decode Helper (CUDA Graph or Non-Graph)
+ # =========================================================================
+ def decode_one_token(token_id: int, position: int, context_len: int) -> np.ndarray:
+ """Decode one token and return logits as numpy array.
+
+ Uses CUDA Graph if enabled, otherwise falls back to standard decode.
+ """
+ if use_cuda_graph and m1_graph is not None:
+ logits = m1_graph.step_graph(token_id, position, context_len)
+ return logits_to_f32(logits)[-1]
+ else:
+ hidden = model._decode_step_fixed_cache(token_id, position, context_len)
+ logits = model.get_logits(hidden)
+ return logits_to_f32(logits)[-1]
+
+ # =========================================================================
+ # Generation Function
+ # =========================================================================
+ def generate(messages: list[dict]) -> tuple[str, float, float, int]:
+ """Generate response using M=1 decode."""
+ prompt = format_chat_messages(messages, model_type=chat_template)
+ input_ids = tokenizer.encode(prompt).ids
+
+ if len(input_ids) >= args.max_seq_len - 10:
+ return "[Error: Conversation too long. Use /clear to reset.]", 0, 0, 0
+
+ # Prefill
+ t_prefill_start = time.perf_counter()
+ hidden, past_key_values = model(input_ids, use_cache=True)
+
+ for i, block in enumerate(model.blocks):
+ past_k, past_v = past_key_values[i]
+ kv_cache_prefill_gqa(past_k, block.attn._k_cache, block.attn.num_heads, start_pos=0)
+ kv_cache_prefill_gqa(past_v, block.attn._v_cache, block.attn.num_heads, start_pos=0)
+ default_stream().synchronize()
+ prefill_time = time.perf_counter() - t_prefill_start
+
+ # Decode
+ t_decode_start = time.perf_counter()
+ logits = model.get_logits(hidden)
+ last_logits = logits_to_f32(logits)[-1]
+ next_token = sample_token(last_logits, args.temperature, args.top_k, args.top_p)
+
+ generated_ids: list[int] = []
+ position = len(input_ids)
+ context_len = position + 1
+
+ # Check if first token is end token
+ if is_end_token(next_token):
+ default_stream().synchronize()
+ decode_time = time.perf_counter() - t_decode_start
+ return "", prefill_time, decode_time, 0
+
+ # Use streaming decoder for UTF-8 safe output
+ stream_decoder = StreamingDecoder(tokenizer)
+
+ # Output first token
+ text_chunk = stream_decoder.add_token(next_token)
+ if text_chunk:
+ print(text_chunk, end="", flush=True)
+ generated_ids.append(next_token)
+
+ while len(generated_ids) < args.max_new_tokens:
+ if context_len >= args.max_seq_len:
+ break
+
+ # Decode one token (CUDA Graph or standard)
+ logits_np = decode_one_token(next_token, position, context_len)
+ logits_np = apply_repetition_penalty(logits_np, generated_ids, args.repetition_penalty)
+ next_token = sample_token(logits_np, args.temperature, args.top_k, args.top_p)
+
+ if is_end_token(next_token):
+ break
+
+ generated_ids.append(next_token)
+ position += 1
+ context_len += 1
+
+ text_chunk = stream_decoder.add_token(next_token)
+ if text_chunk:
+ print(text_chunk, end="", flush=True)
+
+ # Flush any remaining buffered text
+ remaining = stream_decoder.flush()
+ if remaining:
+ print(remaining, end="", flush=True)
+
+ default_stream().synchronize()
+ decode_time = time.perf_counter() - t_decode_start
+
+ print()
+ return tokenizer.decode(generated_ids), prefill_time, decode_time, len(generated_ids)
+
+ # =========================================================================
+ # Chat Loop
+ # =========================================================================
+ print("\n" + "=" * 60)
+ print(" PyGPUkit MoE Chat")
+ if config.num_experts:
+ print(
+ f" Model: {spec.name} ({config.num_experts} experts, top-{config.num_experts_per_tok})"
+ )
+ else:
+ print(f" Model: {spec.name}")
+ print(f" CUDA Graph: {'ON' if use_cuda_graph else 'OFF'}")
+ print(" Commands: /clear (reset), /quit (exit)")
+ print("=" * 60)
+
+ while True:
+ try:
+ user_input = input("\nYou: ").strip()
+ except (EOFError, KeyboardInterrupt):
+ print("\nGoodbye!")
+ break
+
+ if not user_input:
+ continue
+
+ # Commands
+ if user_input.lower() == "/quit":
+ print("Goodbye!")
+ break
+ elif user_input.lower() == "/clear":
+ conversation.clear()
+ print("[Conversation cleared]")
+ continue
+
+ # Add user message
+ conversation.append({"role": "user", "content": user_input})
+
+ # Build full message list (without system prompt for now)
+ messages = conversation
+
+ # Generate response
+ print("\nAssistant: ", end="", flush=True)
+
+ response, prefill_time, decode_time, tokens_generated = generate(messages)
+
+ # Add assistant response to history
+ conversation.append({"role": "assistant", "content": response})
+
+ # Stats
+ decode_tps = tokens_generated / decode_time if decode_time > 0 else 0
+ print(
+ f" [prefill: {prefill_time:.1f}s, "
+ f"decode: {tokens_generated} tok / {decode_time:.1f}s = {decode_tps:.1f} tok/s]"
+ )
+
+ # =========================================================================
+ # Cleanup
+ # =========================================================================
+ print("\nUnloading model...")
+ del model
+ print("Done.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/chat_cli_thinking.py b/examples/chat/chat_cli_thinking.py
similarity index 100%
rename from examples/chat_cli_thinking.py
rename to examples/chat/chat_cli_thinking.py
diff --git a/examples/chat_cli_triton.py b/examples/chat/chat_cli_triton.py
similarity index 100%
rename from examples/chat_cli_triton.py
rename to examples/chat/chat_cli_triton.py
diff --git a/examples/demo_v01.py b/examples/demos/archived/demo_v01.py
similarity index 100%
rename from examples/demo_v01.py
rename to examples/demos/archived/demo_v01.py
diff --git a/examples/demo_v02.py b/examples/demos/archived/demo_v02.py
similarity index 100%
rename from examples/demo_v02.py
rename to examples/demos/archived/demo_v02.py
diff --git a/examples/demo_v0210.py b/examples/demos/archived/demo_v0210.py
similarity index 100%
rename from examples/demo_v0210.py
rename to examples/demos/archived/demo_v0210.py
diff --git a/examples/demo_v0212.py b/examples/demos/archived/demo_v0212.py
similarity index 100%
rename from examples/demo_v0212.py
rename to examples/demos/archived/demo_v0212.py
diff --git a/examples/demo_v023.py b/examples/demos/archived/demo_v023.py
similarity index 100%
rename from examples/demo_v023.py
rename to examples/demos/archived/demo_v023.py
diff --git a/examples/demo_v025.py b/examples/demos/archived/demo_v025.py
similarity index 100%
rename from examples/demo_v025.py
rename to examples/demos/archived/demo_v025.py
diff --git a/examples/demo_v026_multi_llm.py b/examples/demos/archived/demo_v026_multi_llm.py
similarity index 100%
rename from examples/demo_v026_multi_llm.py
rename to examples/demos/archived/demo_v026_multi_llm.py
diff --git a/examples/demo_v02_full.py b/examples/demos/archived/demo_v02_full.py
similarity index 100%
rename from examples/demo_v02_full.py
rename to examples/demos/archived/demo_v02_full.py
diff --git a/native/CMakeLists.txt b/native/CMakeLists.txt
index e32db94..e7e4bfb 100644
--- a/native/CMakeLists.txt
+++ b/native/CMakeLists.txt
@@ -153,6 +153,8 @@ pybind11_add_module(${MODULE_NAME}
ops/reduction/reduction.cu
ops/matmul/matmul.cu
ops/matmul/matmul_cutlass.cu
+ ops/matmul/fused.cu
+ ops/matmul/batched.cu
# GEMM kernels (Issue #122: Reorganized with w{weight}a{act}_{out} naming)
ops/matmul/gemm/f32_f32/generic/f32_ampere.cu
ops/matmul/gemm/w8a8_f32/sm90/fp8_cutlass.cu
diff --git a/native/ops/matmul/batched.cu b/native/ops/matmul/batched.cu
new file mode 100644
index 0000000..52e0e54
--- /dev/null
+++ b/native/ops/matmul/batched.cu
@@ -0,0 +1,49 @@
+/**
+ * Batched matrix multiplication operations
+ *
+ * Currently a placeholder - batched GEMM requires CUTLASS implementation.
+ */
+#include "../../core/memory.hpp"
+#include "../../core/cuda_graph.hpp"
+#include "../common/error.cuh"
+
+#include
+
+namespace pygpukit {
+namespace ops {
+
+/**
+ * Batched strided matrix multiplication (FP32).
+ *
+ * Computes C[i] = A[i] @ B[i] for i in 0..batch_count-1.
+ * Each matrix is accessed via strided offsets from the base pointer.
+ *
+ * @param A Input matrix A, shape [batch_count * strideA]
+ * @param B Input matrix B, shape [batch_count * strideB]
+ * @param C Output matrix C, shape [batch_count * strideC]
+ * @param M Number of rows in A and C
+ * @param N Number of columns in B and C
+ * @param K Number of columns in A / rows in B
+ * @param batch_count Number of batches
+ * @param strideA Stride between A matrices (in elements)
+ * @param strideB Stride between B matrices (in elements)
+ * @param strideC Stride between C matrices (in elements)
+ */
+void batched_matmul_fp32(const GPUArray& A, const GPUArray& B, GPUArray& C,
+ int M, int N, int K, int batch_count,
+ int64_t strideA, int64_t strideB, int64_t strideC) {
+ // Validate inputs
+ if (A.dtype() != DataType::Float32 || B.dtype() != DataType::Float32 || C.dtype() != DataType::Float32) {
+ throw std::runtime_error("batched_matmul_fp32: all inputs must be float32");
+ }
+
+ // TODO: Implement batched GEMM with CUTLASS or cuBLASLt
+ // For now, this is a placeholder that throws
+ (void)M; (void)N; (void)K;
+ (void)batch_count;
+ (void)strideA; (void)strideB; (void)strideC;
+ throw std::runtime_error("batched_matmul_fp32: not yet implemented");
+}
+
+} // namespace ops
+} // namespace pygpukit
diff --git a/native/ops/matmul/fused.cu b/native/ops/matmul/fused.cu
new file mode 100644
index 0000000..51ba4f4
--- /dev/null
+++ b/native/ops/matmul/fused.cu
@@ -0,0 +1,137 @@
+/**
+ * Fused matmul operations (CUTLASS epilogue fusion)
+ */
+#include "../../core/memory.hpp"
+#include "../../core/cuda_graph.hpp"
+#include "../common/error.cuh"
+#include "../ops.cuh" // For transpose(), gelu(), bias_add_inplace()
+
+#include
+#include
+#include
+
+// CUTLASS BiasGELU fused operations (extern declarations from matmul_cutlass.cu)
+extern "C" {
+ cudaError_t cutlass_gemm_tf32_bias_gelu(const float* A, const float* B, const float* bias, float* D, int M, int N, int K, cudaStream_t stream);
+ cudaError_t cutlass_gemm_fp16_bias_gelu(const __half* A, const __half* B, const __half* bias, __half* D, int M, int N, int K, cudaStream_t stream);
+ cudaError_t cutlass_gemm_bf16_bias_gelu(const __nv_bfloat16* A, const __nv_bfloat16* B, const __nv_bfloat16* bias, __nv_bfloat16* D, int M, int N, int K, cudaStream_t stream);
+ bool cutlass_is_compatible(int M, int N, int K);
+ bool cutlass_is_sm_supported();
+}
+
+namespace pygpukit {
+namespace ops {
+
+// Forward declarations for fallback path
+void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c);
+
+/**
+ * Fused linear + bias + GELU activation.
+ *
+ * Computes: output = GELU(input @ weight^T + bias)
+ *
+ * Uses CUTLASS epilogue fusion when available (SM >= 86, dimensions divisible by 16).
+ * Falls back to native matmul + bias_add + gelu when CUTLASS is not available.
+ *
+ * @param input Input tensor [batch, in_features]
+ * @param weight Weight matrix [out_features, in_features]
+ * @param bias Bias vector [out_features]
+ * @return Output tensor [batch, out_features]
+ */
+GPUArray linear_bias_gelu(const GPUArray& input, const GPUArray& weight, const GPUArray& bias) {
+ // Validate shapes: input [batch, in_features], weight [out_features, in_features], bias [out_features]
+ if (input.ndim() != 2) {
+ throw std::runtime_error("linear_bias_gelu: input must be 2D [batch, in_features]");
+ }
+ if (weight.ndim() != 2) {
+ throw std::runtime_error("linear_bias_gelu: weight must be 2D [out_features, in_features]");
+ }
+ if (bias.ndim() != 1) {
+ throw std::runtime_error("linear_bias_gelu: bias must be 1D [out_features]");
+ }
+
+ size_t batch = input.shape()[0];
+ size_t in_features = input.shape()[1];
+ size_t out_features = weight.shape()[0];
+
+ if (weight.shape()[1] != in_features) {
+ throw std::runtime_error("linear_bias_gelu: weight.shape[1] must match input.shape[1]");
+ }
+ if (bias.shape()[0] != out_features) {
+ throw std::runtime_error("linear_bias_gelu: bias.shape[0] must match weight.shape[0]");
+ }
+
+ // Validate dtypes
+ if (input.dtype() != weight.dtype() || input.dtype() != bias.dtype()) {
+ throw std::runtime_error("linear_bias_gelu: all inputs must have the same dtype");
+ }
+
+ // Check if CUTLASS fused kernel can be used
+ // Requirements: dimensions must be multiples of 16 AND SM >= 86
+ bool use_cutlass = cutlass_is_compatible(batch, out_features, in_features) && cutlass_is_sm_supported();
+
+ // Also check if CUTLASS is disabled via environment variable
+ const char* no_cutlass_env = std::getenv("PYGPUKIT_NO_CUTLASS");
+ if (no_cutlass_env && (no_cutlass_env[0] == '1' || no_cutlass_env[0] == 'y' || no_cutlass_env[0] == 'Y')) {
+ use_cutlass = false;
+ }
+
+ // Transpose weight for both paths (needed for input @ weight^T)
+ GPUArray weight_T = transpose(weight); // [in_features, out_features]
+
+ // Allocate output
+ GPUArray output({batch, out_features}, input.dtype());
+
+ if (use_cutlass) {
+ // CUTLASS fused BiasGELU kernel path
+ cudaError_t err = cudaSuccess;
+ cudaStream_t stream = internal::get_capture_stream();
+
+ switch (input.dtype()) {
+ case DataType::Float32:
+ err = cutlass_gemm_tf32_bias_gelu(
+ static_cast(input.data()),
+ static_cast(weight_T.data()),
+ static_cast(bias.data()),
+ static_cast(output.data()),
+ batch, out_features, in_features, stream);
+ break;
+ case DataType::Float16:
+ err = cutlass_gemm_fp16_bias_gelu(
+ static_cast(input.data()),
+ static_cast(weight_T.data()),
+ static_cast(bias.data()),
+ static_cast<__half*>(output.data()),
+ batch, out_features, in_features, stream);
+ break;
+ case DataType::BFloat16:
+ err = cutlass_gemm_bf16_bias_gelu(
+ static_cast(input.data()),
+ static_cast(weight_T.data()),
+ static_cast(bias.data()),
+ static_cast<__nv_bfloat16*>(output.data()),
+ batch, out_features, in_features, stream);
+ break;
+ default:
+ throw std::runtime_error("linear_bias_gelu only supports float32, float16, and bfloat16");
+ }
+
+ // If CUTLASS fails (e.g., not compiled in), fall back to native path
+ if (err == cudaSuccess) {
+ sync_and_check("linear_bias_gelu CUTLASS kernel failed");
+ return output;
+ }
+ // Fall through to native path if CUTLASS returns error
+ }
+
+ // Native fallback path: matmul + bias_add_inplace + gelu
+ // This works for any dimensions and when CUTLASS is not available
+ matmul(input, weight_T, output);
+ bias_add_inplace(output, bias);
+ output = gelu(output);
+
+ return output;
+}
+
+} // namespace ops
+} // namespace pygpukit
diff --git a/native/ops/matmul/matmul.cu b/native/ops/matmul/matmul.cu
index 077111f..bc8b343 100644
--- a/native/ops/matmul/matmul.cu
+++ b/native/ops/matmul/matmul.cu
@@ -28,12 +28,8 @@ extern "C" {
cudaError_t cutlass_gemm_bf16(const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C, int M, int N, int K, cudaStream_t stream);
bool cutlass_is_compatible(int M, int N, int K);
bool cutlass_is_sm_supported();
-
- // BiasGELU fused operations
- cudaError_t cutlass_gemm_tf32_bias_gelu(const float* A, const float* B, const float* bias, float* D, int M, int N, int K, cudaStream_t stream);
- cudaError_t cutlass_gemm_fp16_bias_gelu(const __half* A, const __half* B, const __half* bias, __half* D, int M, int N, int K, cudaStream_t stream);
- cudaError_t cutlass_gemm_bf16_bias_gelu(const __nv_bfloat16* A, const __nv_bfloat16* B, const __nv_bfloat16* bias, __nv_bfloat16* D, int M, int N, int K, cudaStream_t stream);
}
+// BiasGELU fused operations moved to fused.cu
namespace pygpukit {
namespace ops {
@@ -528,137 +524,8 @@ GPUArray matmul(const GPUArray& a, const GPUArray& b, bool use_tf32) {
return c;
}
-// ============================================================================
-// Fused Operations (CUTLASS Epilogue Fusion)
-// ============================================================================
-
-GPUArray linear_bias_gelu(const GPUArray& input, const GPUArray& weight, const GPUArray& bias) {
- // Validate shapes: input [batch, in_features], weight [out_features, in_features], bias [out_features]
- if (input.ndim() != 2) {
- throw std::runtime_error("linear_bias_gelu: input must be 2D [batch, in_features]");
- }
- if (weight.ndim() != 2) {
- throw std::runtime_error("linear_bias_gelu: weight must be 2D [out_features, in_features]");
- }
- if (bias.ndim() != 1) {
- throw std::runtime_error("linear_bias_gelu: bias must be 1D [out_features]");
- }
-
- size_t batch = input.shape()[0];
- size_t in_features = input.shape()[1];
- size_t out_features = weight.shape()[0];
-
- if (weight.shape()[1] != in_features) {
- throw std::runtime_error("linear_bias_gelu: weight.shape[1] must match input.shape[1]");
- }
- if (bias.shape()[0] != out_features) {
- throw std::runtime_error("linear_bias_gelu: bias.shape[0] must match weight.shape[0]");
- }
-
- // Validate dtypes
- if (input.dtype() != weight.dtype() || input.dtype() != bias.dtype()) {
- throw std::runtime_error("linear_bias_gelu: all inputs must have the same dtype");
- }
-
- // Check if CUTLASS fused kernel can be used
- // Requirements: dimensions must be multiples of 16 AND SM >= 86
- bool use_cutlass = cutlass_is_compatible(batch, out_features, in_features) && cutlass_is_sm_supported();
-
- // Also check if CUTLASS is disabled via environment variable
- const char* no_cutlass_env = std::getenv("PYGPUKIT_NO_CUTLASS");
- if (no_cutlass_env && (no_cutlass_env[0] == '1' || no_cutlass_env[0] == 'y' || no_cutlass_env[0] == 'Y')) {
- use_cutlass = false;
- }
-
- // Transpose weight for both paths (needed for input @ weight^T)
- GPUArray weight_T = transpose(weight); // [in_features, out_features]
-
- // Allocate output
- GPUArray output({batch, out_features}, input.dtype());
-
- if (use_cutlass) {
- // CUTLASS fused BiasGELU kernel path
- cudaError_t err = cudaSuccess;
- cudaStream_t stream = internal::get_capture_stream();
-
- switch (input.dtype()) {
- case DataType::Float32:
- err = cutlass_gemm_tf32_bias_gelu(
- static_cast(input.data()),
- static_cast(weight_T.data()),
- static_cast(bias.data()),
- static_cast(output.data()),
- batch, out_features, in_features, stream);
- break;
- case DataType::Float16:
- err = cutlass_gemm_fp16_bias_gelu(
- static_cast(input.data()),
- static_cast(weight_T.data()),
- static_cast(bias.data()),
- static_cast<__half*>(output.data()),
- batch, out_features, in_features, stream);
- break;
- case DataType::BFloat16:
- err = cutlass_gemm_bf16_bias_gelu(
- static_cast(input.data()),
- static_cast(weight_T.data()),
- static_cast(bias.data()),
- static_cast<__nv_bfloat16*>(output.data()),
- batch, out_features, in_features, stream);
- break;
- default:
- throw std::runtime_error("linear_bias_gelu only supports float32, float16, and bfloat16");
- }
-
- // If CUTLASS fails (e.g., not compiled in), fall back to native path
- if (err == cudaSuccess) {
- sync_and_check("linear_bias_gelu CUTLASS kernel failed");
- return output;
- }
- // Fall through to native path if CUTLASS returns error
- }
-
- // Native fallback path: matmul + bias_add_inplace + gelu
- // This works for any dimensions and when CUTLASS is not available
- matmul(input, weight_T, output);
- bias_add_inplace(output, bias);
- output = gelu(output);
-
- return output;
-}
-
-// ============================================================================
-// Batched GEMM Implementation
-// ============================================================================
-
-void batched_matmul_fp32(const GPUArray& A, const GPUArray& B, GPUArray& C,
- int M, int N, int K, int batch_count,
- int64_t strideA, int64_t strideB, int64_t strideC) {
- // Validate inputs
- if (A.dtype() != DataType::Float32 || B.dtype() != DataType::Float32 || C.dtype() != DataType::Float32) {
- throw std::runtime_error("batched_matmul_fp32: all inputs must be float32");
- }
-
-#if PYGPUKIT_HAS_CUTLASS
- // Use CUTLASS batched GEMM
- cudaError_t err = cutlass_gemm::gemm_batched_fp32(
- static_cast(A.data()),
- static_cast(B.data()),
- static_cast(C.data()),
- M, N, K,
- batch_count,
- strideA, strideB, strideC,
- 1.0f, 0.0f, // alpha, beta
- internal::get_capture_stream()
- );
- if (err != cudaSuccess) {
- throw std::runtime_error("batched_matmul_fp32: CUTLASS kernel failed");
- }
- sync_and_check("batched_matmul_fp32 CUTLASS kernel failed");
-#else
- throw std::runtime_error("batched_matmul_fp32: CUTLASS not available");
-#endif
-}
+// Fused operations (linear_bias_gelu) are in fused.cu
+// Batched GEMM (batched_matmul_fp32) are in batched.cu
} // namespace ops
} // namespace pygpukit
diff --git a/src/pygpukit/core/__init__.py b/src/pygpukit/core/__init__.py
index d7eb9de..dcf5ced 100644
--- a/src/pygpukit/core/__init__.py
+++ b/src/pygpukit/core/__init__.py
@@ -4,6 +4,14 @@
from pygpukit.core.device import DeviceInfo, get_device_info, is_cuda_available
from pygpukit.core.dtypes import DataType, float32, float64, int16, int32, int64
from pygpukit.core.factory import empty, from_numpy, ones, zeros
+from pygpukit.core.memory import (
+ copy_device_to_device_async,
+ copy_device_to_device_offset,
+ copy_to_device,
+ copy_to_device_async,
+ get_memory_info,
+ synchronize,
+)
from pygpukit.core.stream import Stream, StreamManager, default_stream
# Import CUDA Event for GPU-side timing (via auto-selecting loader)
@@ -27,23 +35,36 @@
event_elapsed_us = None # type: ignore[assignment]
__all__ = [
+ # Array
"GPUArray",
+ # Device
"DeviceInfo",
"get_device_info",
"is_cuda_available",
+ # Data types
"DataType",
"float64",
"float32",
"int64",
"int32",
"int16",
+ # Factory
"zeros",
"ones",
"empty",
"from_numpy",
+ # Memory
+ "get_memory_info",
+ "copy_to_device",
+ "copy_to_device_async",
+ "copy_device_to_device_async",
+ "copy_device_to_device_offset",
+ "synchronize",
+ # Stream
"Stream",
"StreamManager",
"default_stream",
+ # Events
"CudaEvent",
"event_elapsed_ms",
"event_elapsed_us",
diff --git a/src/pygpukit/core/memory.py b/src/pygpukit/core/memory.py
new file mode 100644
index 0000000..4f3839b
--- /dev/null
+++ b/src/pygpukit/core/memory.py
@@ -0,0 +1,223 @@
+"""Memory management utilities for GPU arrays.
+
+Provides Python wrappers for native memory operations:
+- Memory info (free/total)
+- Async copy operations
+- Device synchronization
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from pygpukit.core.array import GPUArray
+ from pygpukit.core.stream import Stream
+
+
+def get_memory_info() -> tuple[int, int]:
+ """Get GPU memory information.
+
+ Returns:
+ Tuple of (free_bytes, total_bytes).
+
+ Example:
+ free, total = get_memory_info()
+ print(f"Free: {free / 1e9:.2f} GB / Total: {total / 1e9:.2f} GB")
+ """
+ from pygpukit.core.backend import get_backend, has_native_module
+
+ if not has_native_module():
+ # CPU simulation - return dummy values
+ return (8 * 1024**3, 8 * 1024**3) # 8 GB
+
+ backend = get_backend()
+ if not backend.is_available():
+ return (0, 0)
+
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ props = native.get_device_properties()
+ # Native returns total_memory; free requires cudaMemGetInfo
+ # For now return (total - some_estimate, total)
+ return (props.total_memory, props.total_memory)
+
+
+def copy_to_device_async(
+ dst: GPUArray,
+ src_ptr: int,
+ size_bytes: int,
+ stream: Stream,
+) -> None:
+ """Async copy from host pointer to GPUArray.
+
+ Args:
+ dst: Destination GPUArray.
+ src_ptr: Source host memory pointer (as integer).
+ size_bytes: Number of bytes to copy.
+ stream: CUDA stream for async operation.
+
+ Note:
+ For true async behavior, src_ptr should point to pinned memory.
+ Otherwise the copy may block.
+ """
+ from pygpukit.core.backend import get_native_module, has_native_module
+
+ if not has_native_module():
+ raise RuntimeError("copy_to_device_async requires native backend")
+
+ native = get_native_module()
+ native.memcpy_ptr_to_device_async(
+ dst._get_native(),
+ src_ptr,
+ size_bytes,
+ stream._get_native(),
+ )
+
+
+def copy_to_device_async_raw_stream(
+ dst: GPUArray,
+ src_ptr: int,
+ size_bytes: int,
+ stream_handle: int,
+) -> None:
+ """Async copy using raw stream handle (for CUDA Graph).
+
+ Args:
+ dst: Destination GPUArray.
+ src_ptr: Source host memory pointer (as integer).
+ size_bytes: Number of bytes to copy.
+ stream_handle: Raw CUDA stream handle (cudaStream_t as int).
+
+ Note:
+ Used during CUDA Graph capture where Stream object may not be available.
+ """
+ from pygpukit.core.backend import get_native_module, has_native_module
+
+ if not has_native_module():
+ raise RuntimeError("copy_to_device_async_raw_stream requires native backend")
+
+ native = get_native_module()
+ native.memcpy_ptr_to_device_async_raw_stream(
+ dst._get_native(),
+ src_ptr,
+ size_bytes,
+ stream_handle,
+ )
+
+
+def copy_to_device(
+ dst: GPUArray,
+ src_ptr: int,
+ size_bytes: int,
+) -> None:
+ """Synchronous copy from host pointer to GPUArray.
+
+ Args:
+ dst: Destination GPUArray.
+ src_ptr: Source host memory pointer (as integer).
+ size_bytes: Number of bytes to copy.
+
+ Note:
+ This is a blocking operation. Use copy_to_device_async for
+ non-blocking copies.
+ """
+ from pygpukit.core.backend import get_native_module, has_native_module
+
+ if not has_native_module():
+ raise RuntimeError("copy_to_device requires native backend")
+
+ native = get_native_module()
+ native.memcpy_ptr_to_device(
+ dst._get_native(),
+ src_ptr,
+ size_bytes,
+ )
+
+
+def copy_device_to_device_async(
+ dst: GPUArray,
+ src: GPUArray,
+ stream: Stream,
+) -> None:
+ """Async copy between GPUArrays on device.
+
+ Args:
+ dst: Destination GPUArray.
+ src: Source GPUArray.
+ stream: CUDA stream for async operation.
+
+ Note:
+ Both arrays must have the same size in bytes.
+ """
+ from pygpukit.core.backend import get_native_module, has_native_module
+
+ if not has_native_module():
+ raise RuntimeError("copy_device_to_device_async requires native backend")
+
+ if dst.nbytes != src.nbytes:
+ raise ValueError(f"Size mismatch: dst.nbytes={dst.nbytes}, src.nbytes={src.nbytes}")
+
+ native = get_native_module()
+ native.memcpy_device_to_device_async(
+ dst._get_native(),
+ src._get_native(),
+ stream._get_native(),
+ )
+
+
+def copy_device_to_device_offset(
+ dst: GPUArray,
+ dst_offset_bytes: int,
+ src: GPUArray,
+ src_offset_bytes: int,
+ size_bytes: int,
+) -> None:
+ """Copy between GPUArrays with byte offsets.
+
+ Args:
+ dst: Destination GPUArray.
+ dst_offset_bytes: Byte offset in destination.
+ src: Source GPUArray.
+ src_offset_bytes: Byte offset in source.
+ size_bytes: Number of bytes to copy.
+ """
+ from pygpukit.core.backend import get_native_module, has_native_module
+
+ if not has_native_module():
+ raise RuntimeError("copy_device_to_device_offset requires native backend")
+
+ native = get_native_module()
+ native.memcpy_device_to_device_offset(
+ dst._get_native(),
+ dst_offset_bytes,
+ src._get_native(),
+ src_offset_bytes,
+ size_bytes,
+ )
+
+
+def synchronize() -> None:
+ """Synchronize all GPU operations.
+
+ Blocks until all previously issued GPU operations complete.
+ """
+ from pygpukit.core.backend import get_native_module, has_native_module
+
+ if not has_native_module():
+ return # No-op for CPU simulation
+
+ native = get_native_module()
+ native.synchronize()
+
+
+__all__ = [
+ "get_memory_info",
+ "copy_to_device_async",
+ "copy_to_device_async_raw_stream",
+ "copy_to_device",
+ "copy_device_to_device_async",
+ "copy_device_to_device_offset",
+ "synchronize",
+]
diff --git a/src/pygpukit/llm/__init__.py b/src/pygpukit/llm/__init__.py
index 88b0e51..e958c89 100644
--- a/src/pygpukit/llm/__init__.py
+++ b/src/pygpukit/llm/__init__.py
@@ -4,536 +4,21 @@
- SafeTensors file loading with memory mapping
- Tensor metadata and data access
- GPU tensor allocation helpers
+- LLM model implementations (CausalTransformerModel)
+- Layer implementations (Attention, MLP, etc.)
+- Decode strategies (M1, Batch, Jacobi, Speculative)
"""
from __future__ import annotations
-from typing import TYPE_CHECKING
-
-from ..core.backend import get_rust_module
-
-if TYPE_CHECKING:
- from collections.abc import Sequence
-
-# Get the Rust llm module
-_rust = get_rust_module()
-_llm = _rust.llm if _rust else None
-
-
-class Dtype:
- """Tensor data type enumeration."""
-
- Float32 = 0
- Float16 = 1
- BFloat16 = 2
- Float64 = 3
- Float8E4M3 = 4 # FP8 E4M3 (1 sign, 4 exponent, 3 mantissa)
- Float8E5M2 = 5 # FP8 E5M2 (1 sign, 5 exponent, 2 mantissa)
- Int32 = 6
- Int64 = 7
- Int16 = 8
- Int8 = 9
- UInt8 = 10
- Bool = 11
-
- _NAMES = {
- 0: "float32",
- 1: "float16",
- 2: "bfloat16",
- 3: "float64",
- 4: "float8_e4m3",
- 5: "float8_e5m2",
- 6: "int32",
- 7: "int64",
- 8: "int16",
- 9: "int8",
- 10: "uint8",
- 11: "bool",
- }
-
- _SIZES = {
- 0: 4, # float32
- 1: 2, # float16
- 2: 2, # bfloat16
- 3: 8, # float64
- 4: 1, # float8_e4m3
- 5: 1, # float8_e5m2
- 6: 4, # int32
- 7: 8, # int64
- 8: 2, # int16
- 9: 1, # int8
- 10: 1, # uint8
- 11: 1, # bool
- }
-
- @classmethod
- def element_size(cls, dtype: int) -> int:
- """Get the size in bytes of a single element."""
- return cls._SIZES.get(dtype, 0)
-
- @classmethod
- def name(cls, dtype: int) -> str:
- """Get the string name of a dtype."""
- return cls._NAMES.get(dtype, "unknown")
-
-
-class TensorInfo:
- """Metadata for a single tensor in a safetensors file."""
-
- def __init__(
- self,
- name: str,
- dtype: int,
- shape: Sequence[int],
- offset: int,
- size_bytes: int,
- ):
- self.name = name
- self.dtype = dtype
- self.shape = list(shape)
- self.offset = offset
- self.size_bytes = size_bytes
-
- @property
- def numel(self) -> int:
- """Total number of elements."""
- result = 1
- for dim in self.shape:
- result *= dim
- return result
-
- @property
- def dtype_name(self) -> str:
- """String name of the dtype."""
- return Dtype.name(self.dtype)
-
- def __repr__(self) -> str:
- return (
- f"TensorInfo(name='{self.name}', dtype={self.dtype_name}, "
- f"shape={self.shape}, size_bytes={self.size_bytes})"
- )
-
-
-class SafeTensorsFile:
- """Memory-mapped SafeTensors file.
-
- Provides efficient access to tensor metadata and data from a .safetensors file
- using memory mapping for zero-copy data access.
-
- Example:
- >>> st = SafeTensorsFile("model.safetensors")
- >>> print(st.tensor_names)
- ['weight', 'bias']
- >>> info = st.tensor_info('weight')
- >>> print(info.shape, info.dtype_name)
- [768, 768] float16
- >>> data = st.tensor_bytes('weight')
- """
-
- def __init__(self, path: str):
- """Open a safetensors file.
-
- Args:
- path: Path to the .safetensors file
- """
- if _llm is None:
- raise RuntimeError("Rust LLM module not available")
- self._inner = _llm.SafeTensorsFile(path)
-
- @property
- def tensor_names(self) -> list[str]:
- """Get list of all tensor names."""
- return self._inner.tensor_names
-
- @property
- def file_size(self) -> int:
- """Total file size in bytes."""
- return self._inner.file_size
-
- @property
- def num_tensors(self) -> int:
- """Number of tensors in the file."""
- return self._inner.num_tensors
-
- def tensor_info(self, name: str) -> TensorInfo:
- """Get metadata for a tensor by name.
-
- Args:
- name: Tensor name
-
- Returns:
- TensorInfo with dtype, shape, offset, and size
-
- Raises:
- KeyError: If tensor name not found
- """
- info = self._inner.tensor_info(name)
- return TensorInfo(
- name=info.name,
- dtype=int(info.dtype),
- shape=info.shape,
- offset=info.offset,
- size_bytes=info.size_bytes,
- )
-
- def tensor_bytes(self, name: str) -> bytes:
- """Get raw tensor data as bytes.
-
- Args:
- name: Tensor name
-
- Returns:
- Raw bytes of the tensor data
-
- Raises:
- KeyError: If tensor name not found
- """
- return bytes(self._inner.tensor_bytes(name))
-
- def tensor_as_f32(self, name: str):
- """Get tensor data as numpy float32 array.
-
- Args:
- name: Tensor name
-
- Returns:
- 1D numpy array of float32 values
-
- Raises:
- KeyError: If tensor name not found
- ValueError: If tensor dtype is not Float32
- """
- return self._inner.tensor_as_f32(name)
-
- def tensor_data_ptr(self, name: str) -> tuple[int, int]:
- """Get raw mmap pointer for direct GPU transfer.
-
- Args:
- name: Tensor name
-
- Returns:
- Tuple of (ptr, size_bytes) where ptr is the raw mmap address
-
- Raises:
- KeyError: If tensor name not found
- """
- return self._inner.tensor_data_ptr(name)
-
- def __len__(self) -> int:
- return self.num_tensors
-
- def __contains__(self, name: str) -> bool:
- return name in self._inner
-
- def __repr__(self) -> str:
- return f"SafeTensorsFile(num_tensors={self.num_tensors}, file_size={self.file_size})"
-
-
-class ShardedSafeTensorsFile:
- """Sharded SafeTensors file loader.
-
- Handles models split across multiple .safetensors files with an index.json.
- Lazily opens shards on demand to minimize memory usage.
-
- Example:
- >>> st = ShardedSafeTensorsFile("model.safetensors.index.json")
- >>> print(st.tensor_names[:5])
- ['lm_head.weight', 'model.embed_tokens.weight', ...]
- >>> info = st.tensor_info('model.embed_tokens.weight')
- >>> data = st.tensor_bytes('model.embed_tokens.weight')
- """
-
- def __init__(self, index_json_path: str):
- """Open a sharded safetensors model.
-
- Args:
- index_json_path: Path to model.safetensors.index.json
- """
- import json
- from pathlib import Path
-
- self._index_path = Path(index_json_path)
- self._base_dir = self._index_path.parent
-
- with open(index_json_path, encoding="utf-8") as f:
- index = json.load(f)
-
- # weight_map: { tensor_name: shard_filename }
- self._weight_map: dict[str, str] = index.get("weight_map", {})
- self._metadata = index.get("metadata", {})
-
- # Lazy-loaded shard files
- self._shards: dict[str, SafeTensorsFile] = {}
-
- # Unique shard files
- self._shard_files = list(set(self._weight_map.values()))
-
- def _get_shard(self, shard_file: str) -> SafeTensorsFile:
- """Lazily open a shard file."""
- if shard_file not in self._shards:
- shard_path = self._base_dir / shard_file
- self._shards[shard_file] = SafeTensorsFile(str(shard_path))
- return self._shards[shard_file]
-
- @property
- def tensor_names(self) -> list[str]:
- """Get list of all tensor names across all shards."""
- return list(self._weight_map.keys())
-
- @property
- def file_size(self) -> int:
- """Total file size across all shards (lazy, opens all shards)."""
- total = 0
- for shard_file in self._shard_files:
- total += self._get_shard(shard_file).file_size
- return total
-
- @property
- def num_tensors(self) -> int:
- """Number of tensors across all shards."""
- return len(self._weight_map)
-
- def tensor_info(self, name: str) -> TensorInfo:
- """Get metadata for a tensor by name.
-
- Args:
- name: Tensor name
-
- Returns:
- TensorInfo with dtype, shape, offset, and size
-
- Raises:
- KeyError: If tensor name not found
- """
- if name not in self._weight_map:
- raise KeyError(f"Tensor '{name}' not found")
- shard_file = self._weight_map[name]
- return self._get_shard(shard_file).tensor_info(name)
-
- def tensor_bytes(self, name: str) -> bytes:
- """Get raw tensor data as bytes.
-
- Args:
- name: Tensor name
-
- Returns:
- Raw bytes of the tensor data
-
- Raises:
- KeyError: If tensor name not found
- """
- if name not in self._weight_map:
- raise KeyError(f"Tensor '{name}' not found")
- shard_file = self._weight_map[name]
- return self._get_shard(shard_file).tensor_bytes(name)
-
- def tensor_as_f32(self, name: str):
- """Get tensor data as numpy float32 array.
-
- Args:
- name: Tensor name
-
- Returns:
- 1D numpy array of float32 values
-
- Raises:
- KeyError: If tensor name not found
- ValueError: If tensor dtype is not Float32
- """
- if name not in self._weight_map:
- raise KeyError(f"Tensor '{name}' not found")
- shard_file = self._weight_map[name]
- return self._get_shard(shard_file).tensor_as_f32(name)
-
- def tensor_data_ptr(self, name: str) -> tuple[int, int]:
- """Get raw mmap pointer for direct GPU transfer.
-
- Args:
- name: Tensor name
-
- Returns:
- Tuple of (ptr, size_bytes) where ptr is the raw mmap address
-
- Raises:
- KeyError: If tensor name not found
- """
- if name not in self._weight_map:
- raise KeyError(f"Tensor '{name}' not found")
- shard_file = self._weight_map[name]
- return self._get_shard(shard_file).tensor_data_ptr(name)
-
- def __len__(self) -> int:
- return self.num_tensors
-
- def __contains__(self, name: str) -> bool:
- return name in self._weight_map
-
- def __repr__(self) -> str:
- return (
- f"ShardedSafeTensorsFile(num_tensors={self.num_tensors}, "
- f"num_shards={len(self._shard_files)})"
- )
-
-
-def load_safetensors(path: str) -> SafeTensorsFile | ShardedSafeTensorsFile:
- """Load a safetensors file (single or sharded).
-
- Automatically detects sharded models by .index.json extension.
-
- Args:
- path: Path to .safetensors file or .safetensors.index.json
-
- Returns:
- SafeTensorsFile or ShardedSafeTensorsFile for accessing tensor data
-
- Example:
- # Single file
- st = load_safetensors("model.safetensors")
-
- # Sharded model
- st = load_safetensors("model.safetensors.index.json")
- """
- if path.endswith(".index.json"):
- return ShardedSafeTensorsFile(path)
- else:
- return SafeTensorsFile(path)
-
-
-class Tokenizer:
- """BPE Tokenizer for GPT-2 style models.
-
- **⚠️ EXPERIMENTAL: This tokenizer is intended for demos and testing only.**
-
- For production use, we recommend HuggingFace tokenizers:
- - https://github.com/huggingface/tokenizers
- - pip install tokenizers
-
- PyGPUkit's core responsibility is GPU execution, not tokenization.
- The model API expects token IDs as input - use your preferred tokenizer
- to convert text to token IDs before passing to PyGPUkit models.
-
- Limitations:
- - Only supports a subset of HuggingFace tokenizer.json formats
- - May not work with all models (e.g., Qwen3 uses unsupported format)
- - No chat template support
- - No special token handling beyond BOS/EOS/PAD
-
- Example:
- >>> # For demos/testing only
- >>> tok = Tokenizer("tokenizer.json")
- >>> ids = tok.encode("Hello, world!")
- >>> text = tok.decode(ids)
-
- >>> # For production, use HuggingFace tokenizers:
- >>> from tokenizers import Tokenizer as HFTokenizer
- >>> hf_tok = HFTokenizer.from_file("tokenizer.json")
- >>> ids = hf_tok.encode("Hello, world!").ids
- """
-
- def __init__(self, path: str):
- """Load tokenizer from tokenizer.json file.
-
- Args:
- path: Path to the tokenizer.json file
- """
- if _llm is None:
- raise RuntimeError("Rust LLM module not available")
- self._inner = _llm.Tokenizer(path)
-
- @classmethod
- def from_json(cls, json_str: str) -> Tokenizer:
- """Load tokenizer from JSON string.
-
- Args:
- json_str: JSON string containing tokenizer config
-
- Returns:
- Tokenizer instance
- """
- if _llm is None:
- raise RuntimeError("Rust LLM module not available")
- instance = cls.__new__(cls)
- instance._inner = _llm.Tokenizer.from_json(json_str)
- return instance
-
- @property
- def vocab_size(self) -> int:
- """Get vocabulary size."""
- return self._inner.vocab_size
-
- @property
- def bos_token_id(self) -> int | None:
- """Get BOS (beginning of sequence) token ID if available."""
- return self._inner.bos_token_id
-
- @property
- def eos_token_id(self) -> int | None:
- """Get EOS (end of sequence) token ID if available."""
- return self._inner.eos_token_id
-
- @property
- def pad_token_id(self) -> int | None:
- """Get PAD token ID if available."""
- return self._inner.pad_token_id
-
- def encode(self, text: str) -> list[int]:
- """Encode text to token IDs.
-
- Args:
- text: Input text to encode
-
- Returns:
- List of token IDs
- """
- return list(self._inner.encode(text))
-
- def decode(self, token_ids: list[int]) -> str:
- """Decode token IDs to text.
-
- Args:
- token_ids: List of token IDs
-
- Returns:
- Decoded text string
- """
- return self._inner.decode(token_ids)
-
- def id_to_token(self, token_id: int) -> str | None:
- """Get token string for an ID.
-
- Args:
- token_id: Token ID
-
- Returns:
- Token string if ID is valid, None otherwise
- """
- return self._inner.id_to_token(token_id)
-
- def token_to_id(self, token: str) -> int | None:
- """Get ID for a token string.
-
- Args:
- token: Token string
-
- Returns:
- Token ID if token exists, None otherwise
- """
- return self._inner.token_to_id(token)
-
- def __len__(self) -> int:
- return self.vocab_size
-
- def __repr__(self) -> str:
- return f"Tokenizer(vocab_size={self.vocab_size})"
-
-
-# Chat template support (v0.2.10)
# Buffers (refactored v0.2.11)
-from pygpukit.llm.buffers import ( # noqa: E402
+from pygpukit.llm.buffers import (
DecodeBuffers,
PrefillBuffers,
)
-from pygpukit.llm.chat import ( # noqa: E402
+
+# Chat template support (v0.2.10)
+from pygpukit.llm.chat import (
ChatMessage,
apply_chat_template,
create_chat_prompt,
@@ -541,7 +26,7 @@ def __repr__(self) -> str:
)
# Config classes and ModelSpec (refactored v0.2.11)
-from pygpukit.llm.config import ( # noqa: E402
+from pygpukit.llm.config import (
GPT2_SPEC,
LLAMA_SPEC,
MIXTRAL_SPEC,
@@ -558,7 +43,7 @@ def __repr__(self) -> str:
)
# Decode strategies (refactored v0.2.11)
-from pygpukit.llm.decode import ( # noqa: E402
+from pygpukit.llm.decode import (
DecodeBatch,
DecodeJacobi,
DecodeM1,
@@ -567,11 +52,11 @@ def __repr__(self) -> str:
DecodeStrategy,
)
-# Layers (refactored v0.2.11)
-from pygpukit.llm.layers import ( # noqa: E402
+# Layers (refactored v0.2.18)
+from pygpukit.llm.layers import (
MLP,
Attention,
- Linear, # Backward compatibility alias
+ Linear,
LinearBF16,
LinearFP8,
MoELayer,
@@ -586,7 +71,7 @@ def __repr__(self) -> str:
# Loaders (refactored v0.2.11)
# Quantization/Optimization configs (v0.2.18 - Issue #115)
-from pygpukit.llm.loader import ( # noqa: E402 # noqa: E402
+from pygpukit.llm.loader import (
FP8QuantConfig,
ModelOptimizationInfo,
PruningConfig,
@@ -600,9 +85,8 @@ def __repr__(self) -> str:
repack_model_weights,
)
-# Model (refactored v0.2.11)
-from pygpukit.llm.model import ( # noqa: E402
- # Type aliases
+# Model (refactored v0.2.18)
+from pygpukit.llm.model import (
CausalSelfAttention,
CausalTransformerModel,
GPT2Model,
@@ -614,8 +98,20 @@ def __repr__(self) -> str:
RMSNorm,
)
+# SafeTensors (extracted v0.2.18)
+from pygpukit.llm.safetensors import (
+ Dtype,
+ SafeTensorsFile,
+ ShardedSafeTensorsFile,
+ TensorInfo,
+ load_safetensors,
+)
+
# Sampling (refactored v0.2.11)
-from pygpukit.llm.sampling import sample_token # noqa: E402
+from pygpukit.llm.sampling import sample_token
+
+# Tokenizer (extracted v0.2.18)
+from pygpukit.llm.tokenizer import Tokenizer
__all__ = [
# SafeTensors
diff --git a/src/pygpukit/llm/layers.py b/src/pygpukit/llm/layers.py
deleted file mode 100644
index 573642c..0000000
--- a/src/pygpukit/llm/layers.py
+++ /dev/null
@@ -1,1491 +0,0 @@
-"""Neural network layer implementations for PyGPUkit LLM.
-
-Provides:
-- LinearBF16: Dense layer with BF16 weights
-- LinearFP8: Dense layer with FP8 weights (online dequantization)
-- Norm: RMSNorm and LayerNorm
-- Attention: Multi-head attention with RoPE, GQA, QK-Norm, KV cache
-- MLP: Feed-forward network (GELU/SwiGLU)
-- TransformerBlock: Attention + MLP with residual connections
-- RoPE utilities: precompute_freqs_cis, apply_rotary_pos_emb_numpy
-- Repack utilities: repack_weight, repack_linear, repack_norm
-"""
-
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Literal
-
-import numpy as np
-
-from pygpukit.core.array import GPUArray
-from pygpukit.core.dtypes import bfloat16 as dt_bfloat16
-from pygpukit.core.dtypes import float16 as dt_float16
-from pygpukit.core.factory import from_numpy, zeros
-from pygpukit.ops.basic import (
- add,
- bias_add_inplace,
- concat_axis0,
- copy_to,
- gelu,
- gemv_bf16,
- gemv_fp8_bf16,
- kv_cache_prefill_gqa,
- kv_cache_update_gqa,
- layernorm,
- matmul,
- mul,
- repeat_interleave_axis1,
- reshape_copy,
- rmsnorm,
- rope_inplace,
- sdpa_causal,
- sdpa_causal_fixed_cache,
- silu,
- slice_rows_range_ptr,
- split_qkv_batch,
- transpose,
- transpose_3d_021,
- w8a16_gemm_sm120,
-)
-
-if TYPE_CHECKING:
- from pygpukit.llm.buffers import DecodeBuffers
- from pygpukit.llm.config import TransformerConfig
-
-
-# =============================================================================
-# Common Building Blocks
-# =============================================================================
-
-
-class LinearBF16:
- """BF16 Linear layer: y = xW^T + b
-
- Weights are stored as [out_features, in_features] (PyTorch convention).
-
- For M=1 (single token decode), uses custom GEMV kernel which is 4-6x faster
- than cuBLASLt matmul. Automatically falls back to matmul for batch > 1.
- """
-
- # Class-level flag to enable/disable GEMV optimization
- _use_gemv: bool = True
-
- def __init__(self, weight: GPUArray, bias: GPUArray | None = None):
- if weight.ndim != 2:
- raise ValueError(f"weight must be 2D, got {weight.ndim}D")
- self.weight = weight
- self.bias = bias
- self.out_features = weight.shape[0]
- self.in_features = weight.shape[1]
- self._weight_t: GPUArray | None = None
-
- def __call__(self, x: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
- """Forward pass: y = xW^T + b
-
- Args:
- x: Input tensor [batch, in_features]
- out: Optional output buffer [batch, out_features]. If provided,
- result is written in-place (for CUDA Graph capture).
- """
- if x.ndim != 2:
- raise ValueError(f"input must be 2D [batch, in_features], got {x.ndim}D")
- if x.shape[1] != self.in_features:
- raise ValueError(f"input features {x.shape[1]} != weight {self.in_features}")
-
- if self._weight_t is None:
- self._weight_t = transpose(self.weight)
-
- # Use GEMV for M=1 with BF16 (1.3-2.4x faster than matmul)
- # Skip GEMV when out is provided (CUDA Graph mode) - GEMV allocates internally
- use_gemv = (
- LinearBF16._use_gemv
- and x.shape[0] == 1
- and x.dtype == dt_bfloat16
- and out is None # GEMV allocates, not compatible with CUDA Graph
- )
-
- if use_gemv:
- # GEMV path for M=1 decode
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- x_1d = x.view((self.in_features,))
-
- # Use optimized kernel (SM80+) with B[N,K] layout
- if native.gemv_bf16_opt_available():
- y_1d = zeros((self.out_features,), dtype="bfloat16")
- # gemv_bf16_opt: A[K] @ B[N,K]^T -> C[N]
- native.gemv_bf16_opt_sm120(
- x_1d._get_native(),
- self.weight._get_native(), # [N, K] - no transpose
- y_1d._get_native(),
- )
- else:
- # Fallback: old kernel with B[K,N] layout
- y_1d = gemv_bf16(x_1d, self._weight_t)
-
- y = y_1d.view((1, self.out_features))
- else:
- # Standard matmul path
- y = matmul(x, self._weight_t, out=out)
-
- if self.bias is not None:
- bias_add_inplace(y, self.bias)
-
- return y
-
-
-# Backward compatibility alias
-Linear = LinearBF16
-
-
-class LinearFP8:
- """FP8 Linear layer with online dequantization: y = x @ dequant(W)^T + b
-
- Stores weights in FP8 E4M3 format with block-wise scaling factors.
- Dequantizes on-the-fly during forward pass using CUDA kernel.
-
- Memory savings: 50% vs BF16 (1 byte vs 2 bytes per weight + small scale overhead)
-
- For M=1 (single token decode), uses FP8 GEMV kernel with online dequantization.
- For larger batches, falls back to CPU dequantization + GPU matmul.
- """
-
- # Class-level flag to enable/disable GEMV optimization
- _use_gemv: bool = True
-
- # FP8 E4M3 to float32 lookup table (for CPU fallback)
- _FP8_TABLE: np.ndarray | None = None
-
- @classmethod
- def _get_fp8_table(cls) -> np.ndarray:
- """Build FP8 E4M3 to float32 conversion lookup table."""
- if cls._FP8_TABLE is not None:
- return cls._FP8_TABLE
-
- table = np.zeros(256, dtype=np.float32)
- for i in range(256):
- sign = (i >> 7) & 1
- exp = (i >> 3) & 0xF
- mant = i & 0x7
-
- if exp == 0xF and mant == 0x7:
- table[i] = np.nan
- elif exp == 0:
- value = (mant / 8.0) * (2.0**-6)
- table[i] = -value if sign else value
- else:
- value = (1.0 + mant / 8.0) * (2.0 ** (exp - 7))
- table[i] = -value if sign else value
-
- cls._FP8_TABLE = table
- return table
-
- def __init__(
- self,
- weight_fp8: GPUArray, # [out_features, in_features] as uint8
- scale_inv: GPUArray, # [out_features // block_h, in_features // block_w] as bf16
- bias: GPUArray | None = None,
- block_size: tuple[int, int] = (128, 128),
- ):
- if weight_fp8.ndim != 2:
- raise ValueError(f"weight must be 2D, got {weight_fp8.ndim}D")
- self.weight_fp8 = weight_fp8
- self.scale_inv = scale_inv
- self.bias = bias
- self.block_size = block_size
- self.out_features = weight_fp8.shape[0]
- self.in_features = weight_fp8.shape[1]
-
- # Transposed weight for GEMV: [in_features, out_features]
- # FP8 GEMV expects B[K,N] where K=in_features, N=out_features
- self._weight_fp8_t: GPUArray | None = None
- self._scale_inv_t: GPUArray | None = None
-
- # Cached dequantized weight for fallback (lazy initialization)
- self._weight_dequant: GPUArray | None = None
- self._weight_dequant_t: GPUArray | None = None
-
- def _ensure_transposed_fp8(self) -> None:
- """Ensure transposed FP8 weight is available for GEMV."""
- if self._weight_fp8_t is None:
- # Transpose weight: [out, in] -> [in, out]
- self._weight_fp8_t = transpose(self.weight_fp8)
- # Transpose scale: [out/128, in/128] -> [in/128, out/128]
- self._scale_inv_t = transpose(self.scale_inv)
-
- def _dequantize_cpu(self) -> np.ndarray:
- """Dequantize FP8 weight to float32 on CPU."""
- table = self._get_fp8_table()
-
- # Get FP8 bytes
- fp8_np = self.weight_fp8.to_numpy()
- if fp8_np.dtype != np.uint8:
- fp8_np = fp8_np.view(np.uint8)
-
- # Convert to float32
- f32 = table[fp8_np.ravel()].reshape(fp8_np.shape)
-
- # Get scale_inv (bf16 as uint16)
- scale_np = self.scale_inv.to_numpy()
- if scale_np.dtype == np.uint16:
- scale_f32 = np.empty(scale_np.shape, dtype=np.float32)
- scale_f32.view(np.uint32)[:] = scale_np.astype(np.uint32) << 16
- else:
- scale_f32 = scale_np.astype(np.float32)
-
- # Apply block-wise scaling
- H, W = f32.shape
- block_h, block_w = self.block_size
- num_blocks_h = H // block_h
- num_blocks_w = W // block_w
-
- f32_reshaped = f32.reshape(num_blocks_h, block_h, num_blocks_w, block_w)
- scale_expanded = scale_f32[:, np.newaxis, :, np.newaxis]
- f32_scaled = f32_reshaped * scale_expanded
-
- return f32_scaled.reshape(H, W)
-
- def _ensure_dequantized(self) -> None:
- """Ensure dequantized weight is available (lazy init, for fallback)."""
- if self._weight_dequant is None:
- # Dequantize on CPU and upload to GPU
- weight_f32 = self._dequantize_cpu()
-
- # Convert to BF16
- uint32_view = weight_f32.view(np.uint32)
- weight_bf16 = ((uint32_view + 0x7FFF + ((uint32_view >> 16) & 1)) >> 16).astype(
- np.uint16
- )
-
- self._weight_dequant = from_numpy(weight_bf16)
- self._weight_dequant_t = transpose(self._weight_dequant)
-
- def __call__(self, x: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
- """Forward pass with online dequantization.
-
- For M=1 (single token), uses FP8 GEMV kernel with online dequantization.
- For M>1, uses batched FP8 GEMV kernel.
- """
- if x.ndim != 2:
- raise ValueError(f"input must be 2D [batch, in_features], got {x.ndim}D")
- if x.shape[1] != self.in_features:
- raise ValueError(f"input features {x.shape[1]} != weight {self.in_features}")
-
- M = x.shape[0]
-
- if M == 1 and self._use_gemv:
- # M=1 path: Use FP8 GEMV kernel with B[N,K] layout (no transpose needed)
- x_1d = x.view((self.in_features,))
-
- if out is not None:
- out_1d = out.view((self.out_features,))
- gemv_fp8_bf16(x_1d, self.weight_fp8, self.scale_inv, out=out_1d)
- y = out
- else:
- y_1d = gemv_fp8_bf16(x_1d, self.weight_fp8, self.scale_inv)
- y = y_1d.view((1, self.out_features))
- else:
- # M>1 path: Use W8A16 GEMM with FP8 TensorCore (requires transposed weights)
- self._ensure_transposed_fp8()
- y = w8a16_gemm_sm120(x, self._weight_fp8_t, self._scale_inv_t, out=out)
-
- if self.bias is not None:
- bias_add_inplace(y, self.bias)
-
- return y
-
-
-class Norm:
- """Unified normalization layer supporting RMSNorm and LayerNorm."""
-
- def __init__(
- self,
- weight: GPUArray,
- bias: GPUArray | None = None,
- norm_type: Literal["rmsnorm", "layernorm"] = "rmsnorm",
- eps: float = 1e-5,
- ):
- self.weight = weight
- self.bias = bias
- self.norm_type = norm_type
- self.eps = eps
-
- def __call__(self, x: GPUArray) -> GPUArray:
- if self.norm_type == "rmsnorm":
- return rmsnorm(x, self.weight, self.eps)
- else:
- if self.bias is None:
- raise ValueError("LayerNorm requires bias")
- return layernorm(x, self.weight, self.bias, self.eps)
-
-
-# =============================================================================
-# Weight Repacking - Fix GPU memory placement for optimal performance
-# =============================================================================
-
-
-def repack_weight(weight: GPUArray) -> GPUArray:
- """Repack a weight tensor into a new contiguous GPU buffer.
-
- This fixes performance issues caused by fragmented GPU memory allocation.
- Weights allocated later during model loading may end up in suboptimal
- memory regions, causing 7x slower matmul performance.
-
- Args:
- weight: Original weight tensor on GPU
-
- Returns:
- New GPUArray with same data in freshly allocated contiguous memory
- """
- # Copy to CPU, then back to GPU to get fresh allocation
- # This ensures the new buffer is allocated contiguously
- weight_np = weight.to_numpy()
- return from_numpy(weight_np)
-
-
-def repack_linear(linear: LinearBF16) -> None:
- """Repack a LinearBF16 layer's weight in-place.
-
- Args:
- linear: LinearBF16 layer to repack
- """
- linear.weight = repack_weight(linear.weight)
- # Clear transpose cache - will be regenerated on first use
- linear._weight_t = None
- if linear.bias is not None:
- linear.bias = repack_weight(linear.bias)
-
-
-def repack_norm(norm: Norm) -> None:
- """Repack a Norm layer's weight in-place.
-
- Args:
- norm: Norm layer to repack
- """
- norm.weight = repack_weight(norm.weight)
- if norm.bias is not None:
- norm.bias = repack_weight(norm.bias)
-
-
-# =============================================================================
-# RoPE (Rotary Position Embedding)
-# =============================================================================
-
-
-def precompute_freqs_cis(
- head_dim: int, max_seq_len: int, theta: float = 10000.0
-) -> tuple[np.ndarray, np.ndarray]:
- """Precompute rotary embedding cos/sin tables."""
- freqs = 1.0 / (theta ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim))
- t = np.arange(max_seq_len, dtype=np.float32)
- freqs = np.outer(t, freqs)
- cos = np.cos(freqs)
- sin = np.sin(freqs)
- cos = np.concatenate([cos, cos], axis=-1)
- sin = np.concatenate([sin, sin], axis=-1)
- return cos, sin
-
-
-def apply_rotary_pos_emb_numpy(
- q: np.ndarray, k: np.ndarray, cos: np.ndarray, sin: np.ndarray
-) -> tuple[np.ndarray, np.ndarray]:
- """Apply rotary position embeddings to Q and K (numpy version)."""
-
- def rotate_half(x: np.ndarray) -> np.ndarray:
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return np.concatenate([-x2, x1], axis=-1)
-
- cos = cos[:, np.newaxis, :]
- sin = sin[:, np.newaxis, :]
-
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-# =============================================================================
-# Unified Attention
-# =============================================================================
-
-
-class Attention:
- """Unified attention with Hybrid CPU/GPU execution.
-
- Supports:
- - Multi-Head Attention (MHA): num_kv_heads == num_heads
- - Grouped Query Attention (GQA): num_kv_heads < num_heads
- - RoPE: enabled via config.use_rope
- - QK Norm: optional normalization of Q and K (Qwen3 style)
- - Hybrid execution: CPU for seq_len=1, GPU for longer sequences
- - FP8 quantized weights via LinearFP8
- """
-
- def __init__(
- self,
- q_proj: GPUArray | LinearBF16 | LinearFP8,
- k_proj: GPUArray | LinearBF16 | LinearFP8,
- v_proj: GPUArray | LinearBF16 | LinearFP8,
- o_proj: GPUArray | LinearBF16 | LinearFP8,
- config: TransformerConfig,
- q_bias: GPUArray | None = None,
- k_bias: GPUArray | None = None,
- v_bias: GPUArray | None = None,
- o_bias: GPUArray | None = None,
- q_norm: Norm | None = None,
- k_norm: Norm | None = None,
- ):
- # Accept either GPUArray (wrapped in LinearBF16) or pre-built LinearBF16/LinearFP8
- def wrap_linear(
- proj: GPUArray | LinearBF16 | LinearFP8, bias: GPUArray | None
- ) -> LinearBF16 | LinearFP8:
- if isinstance(proj, (LinearBF16, LinearFP8)):
- return proj
- return LinearBF16(proj, bias)
-
- self.q_proj = wrap_linear(q_proj, q_bias)
- self.k_proj = wrap_linear(k_proj, k_bias)
- self.v_proj = wrap_linear(v_proj, v_bias)
- self.o_proj = wrap_linear(o_proj, o_bias)
-
- # QK Norm (Qwen3 style)
- self.q_norm = q_norm
- self.k_norm = k_norm
-
- self.config = config
- self.head_dim = config.head_dim
- self.num_heads = config.num_heads
- assert config.num_kv_heads is not None # Set in __post_init__
- self.num_kv_heads: int = config.num_kv_heads
- self.num_kv_groups = config.num_kv_groups
-
- # Store dimensions for QKV split
- self.q_dim = self.num_heads * self.head_dim
- self.k_dim = self.num_kv_heads * self.head_dim
- self.v_dim = self.num_kv_heads * self.head_dim
-
- # Create fused QKV projection (reduces 3 matmuls to 1)
- # Skip fusion for FP8 (LinearFP8 can't be concatenated)
- self.qkv_proj: LinearBF16 | None = None
- if not isinstance(self.q_proj, LinearFP8):
- # Extract weights from LinearBF16 for concatenation
- q_weight = self.q_proj.weight if isinstance(self.q_proj, LinearBF16) else q_proj
- k_weight = self.k_proj.weight if isinstance(self.k_proj, LinearBF16) else k_proj
- v_weight = self.v_proj.weight if isinstance(self.v_proj, LinearBF16) else v_proj
- qkv_weight = concat_axis0(concat_axis0(q_weight, k_weight), v_weight)
- self.qkv_proj = LinearBF16(qkv_weight, None)
-
- # Precompute RoPE if enabled
- self._cos: np.ndarray | None
- self._sin: np.ndarray | None
- if config.use_rope:
- self._cos, self._sin = precompute_freqs_cis(
- self.head_dim, config.max_position_embeddings, config.rope_theta
- )
- else:
- self._cos, self._sin = None, None
-
- # Fixed-length KV cache for CUDA Graph (initialized on first use)
- self._k_cache: GPUArray | None = None
- self._v_cache: GPUArray | None = None
- self._max_cache_len: int = 0
-
- # Lookahead KV tracking for Jacobi decoding
- self._confirmed_pos: int = 0
- self._logical_pos: int = 0
-
- def init_fixed_cache(self, max_seq_len: int, dtype: str = "float16") -> None:
- """Initialize fixed-length KV cache for CUDA Graph capture.
-
- Args:
- max_seq_len: Maximum sequence length to support.
- dtype: Data type for cache (float16/bfloat16/float32).
- """
- cache_shape = (self.num_heads, max_seq_len, self.head_dim)
- if dtype == "float16":
- np_dtype = np.float16
- elif dtype == "bfloat16":
- np_dtype = np.uint16 # bf16 stored as uint16
- else:
- np_dtype = np.float32
- self._k_cache = from_numpy(np.zeros(cache_shape, dtype=np_dtype))
- self._v_cache = from_numpy(np.zeros(cache_shape, dtype=np_dtype))
- self._max_cache_len = max_seq_len
- self._confirmed_pos = 0
- self._logical_pos = 0
-
- # =========================================================================
- # Lookahead KV Cache Management (for Jacobi Decoding)
- # =========================================================================
-
- def set_confirmed_pos(self, pos: int) -> None:
- """Set the confirmed position (e.g., after prefill)."""
- assert 0 <= pos <= self._max_cache_len, f"Invalid pos {pos}"
- self._confirmed_pos = pos
- self._logical_pos = pos
-
- def reset_lookahead(self) -> None:
- """Reset lookahead pointer to confirmed position."""
- self._logical_pos = self._confirmed_pos
-
- def commit_lookahead(self, n_accepted: int) -> None:
- """Commit accepted tokens by advancing confirmed_pos."""
- new_pos = self._confirmed_pos + n_accepted
- assert new_pos <= self._max_cache_len, f"Commit exceeds cache: {new_pos}"
- self._confirmed_pos = new_pos
- self._logical_pos = new_pos
-
- def get_confirmed_pos(self) -> int:
- """Get current confirmed position."""
- return self._confirmed_pos
-
- def __call__(
- self,
- x: GPUArray,
- position_ids: list[int] | None = None,
- past_kv: tuple | None = None,
- use_cache: bool = False,
- ) -> tuple[GPUArray, tuple | None]:
- """Forward pass with hybrid CPU/GPU attention.
-
- Args:
- x: Input tensor [seq_len, hidden_size]
- position_ids: Position IDs for RoPE (auto-generated if None)
- past_kv: Tuple of (past_k, past_v) numpy arrays
- use_cache: Whether to return KV cache
-
- Returns:
- Tuple of (output, present_kv)
- """
- seq_len = x.shape[0]
-
- if position_ids is None:
- position_ids = list(range(seq_len))
-
- return self._forward_gpu(x, position_ids, past_kv, use_cache)
-
- def _forward_gpu(
- self,
- x: GPUArray,
- position_ids: list[int],
- past_kv: tuple | None,
- use_cache: bool,
- ) -> tuple[GPUArray, tuple | None]:
- """GPU path for long sequences (prefill)."""
- seq_len = x.shape[0]
-
- # Project Q, K, V
- q = self.q_proj(x)
- k = self.k_proj(x)
- v = self.v_proj(x)
-
- # Reshape for multi-head
- q = reshape_copy(q, (seq_len, self.num_heads, self.head_dim))
- k = reshape_copy(k, (seq_len, self.num_kv_heads, self.head_dim))
- v = reshape_copy(v, (seq_len, self.num_kv_heads, self.head_dim))
-
- # QK Norm (Qwen3 style)
- if self.q_norm is not None:
- q_shape = (seq_len, self.num_heads, self.head_dim)
- q_2d = reshape_copy(q, (seq_len * self.num_heads, self.head_dim))
- q_2d = self.q_norm(q_2d)
- q = reshape_copy(q_2d, q_shape)
- if self.k_norm is not None:
- k_shape = (seq_len, self.num_kv_heads, self.head_dim)
- k_2d = reshape_copy(k, (seq_len * self.num_kv_heads, self.head_dim))
- k_2d = self.k_norm(k_2d)
- k = reshape_copy(k_2d, k_shape)
-
- # Apply RoPE on GPU
- if self.config.use_rope:
- assert self._cos is not None and self._sin is not None
- from pygpukit.ops.basic import rope_inplace_f32table
-
- q_dtype = q.dtype
- cos_f32 = from_numpy(self._cos[position_ids].astype(np.float32))
- sin_f32 = from_numpy(self._sin[position_ids].astype(np.float32))
- if q_dtype in (dt_float16, dt_bfloat16):
- # Use f32 tables directly for higher precision (no intermediate alloc)
- rope_inplace_f32table(q, k, cos_f32, sin_f32)
- else:
- rope_inplace(q, k, cos_f32, sin_f32)
-
- # GPU KV Cache
- if past_kv is not None:
- past_k, past_v = past_kv
- if isinstance(past_k, GPUArray):
- k = concat_axis0(past_k, k)
- v = concat_axis0(past_v, v)
- else:
- k_np = k.to_numpy()
- v_np = v.to_numpy()
- k_np = np.concatenate([past_k, k_np], axis=0)
- v_np = np.concatenate([past_v, v_np], axis=0)
- k = from_numpy(k_np)
- v = from_numpy(v_np)
-
- present_kv = (k, v) if use_cache else None
-
- # Expand for GQA on GPU
- if self.num_kv_groups > 1:
- k_expanded = repeat_interleave_axis1(k, self.num_kv_groups)
- v_expanded = repeat_interleave_axis1(v, self.num_kv_groups)
- else:
- k_expanded = k
- v_expanded = v
-
- # GPU SDPA
- q_t = transpose_3d_021(q)
- k_t = transpose_3d_021(k_expanded)
- v_t = transpose_3d_021(v_expanded)
-
- attn_output = sdpa_causal(q_t, k_t, v_t)
- attn_output = transpose_3d_021(attn_output)
- attn_output = reshape_copy(attn_output, (seq_len, self.num_heads * self.head_dim))
-
- return self.o_proj(attn_output), present_kv
-
- def forward_fixed_cache(
- self,
- x: GPUArray,
- position: int,
- context_len: int,
- *,
- out: GPUArray | None = None,
- ) -> GPUArray:
- """Forward pass using fixed-length KV cache (for CUDA Graph decode).
-
- Args:
- x: Input tensor [1, hidden_size] - single token
- position: Current position in sequence (for RoPE and cache update)
- context_len: Total context length (prefill + decoded so far)
- out: Optional pre-allocated output buffer
-
- Returns:
- Output tensor [1, hidden_size]
- """
- assert self._k_cache is not None, "Call init_fixed_cache first"
- assert x.shape[0] == 1, "forward_fixed_cache expects single token"
-
- if self.qkv_proj is not None:
- # Fused QKV projection (faster for non-FP8)
- qkv = self.qkv_proj(x)
- q_2d = qkv.narrow(0, self.q_dim)
- k_2d = qkv.narrow(self.q_dim, self.k_dim)
- v_2d = qkv.narrow(self.q_dim + self.k_dim, self.v_dim)
-
- # Apply biases separately
- if self.q_proj.bias is not None:
- bias_add_inplace(q_2d, self.q_proj.bias)
- if self.k_proj.bias is not None:
- bias_add_inplace(k_2d, self.k_proj.bias)
- if self.v_proj.bias is not None:
- bias_add_inplace(v_2d, self.v_proj.bias)
- else:
- # Separate projections (for FP8)
- q_2d = self.q_proj(x)
- k_2d = self.k_proj(x)
- v_2d = self.v_proj(x)
-
- # Zero-copy reshape
- q = q_2d.view((1, self.num_heads, self.head_dim))
- k = k_2d.view((1, self.num_kv_heads, self.head_dim))
- v = v_2d.view((1, self.num_kv_heads, self.head_dim))
-
- # QK Norm
- if self.q_norm is not None:
- q_flat = q.view((self.num_heads, self.head_dim))
- q_normed = self.q_norm(q_flat)
- q = q_normed.view((1, self.num_heads, self.head_dim))
- if self.k_norm is not None:
- k_flat = k.view((self.num_kv_heads, self.head_dim))
- k_normed = self.k_norm(k_flat)
- k = k_normed.view((1, self.num_kv_heads, self.head_dim))
-
- q_dtype = q.dtype
-
- # Apply RoPE
- if self.config.use_rope and self._cos is not None and self._sin is not None:
- from pygpukit.ops.basic import rope_inplace_f32table
-
- cos_f32 = from_numpy(self._cos[position : position + 1].astype(np.float32))
- sin_f32 = from_numpy(self._sin[position : position + 1].astype(np.float32))
- if q_dtype in (dt_float16, dt_bfloat16):
- rope_inplace_f32table(q, k, cos_f32, sin_f32)
- else:
- rope_inplace(q, k, cos_f32, sin_f32)
-
- # Update KV cache
- kv_cache_update_gqa(k, self._k_cache, self.num_heads, position)
- kv_cache_update_gqa(v, self._v_cache, self.num_heads, position)
-
- q_t = q.view((self.num_heads, 1, self.head_dim))
-
- # Allocate output buffer if needed
- if out is None:
- if q_dtype == dt_float16:
- out_np_dtype = np.float16
- elif q_dtype == dt_bfloat16:
- out_np_dtype = np.uint16
- else:
- out_np_dtype = np.float32
- attn_out = from_numpy(np.zeros((self.num_heads, 1, self.head_dim), dtype=out_np_dtype))
- else:
- attn_out = out
-
- sdpa_causal_fixed_cache(q_t, self._k_cache, self._v_cache, attn_out, context_len)
-
- attn_output = attn_out.view((1, self.num_heads * self.head_dim))
- return self.o_proj(attn_output)
-
- def forward_fixed_cache_batch(
- self,
- x: GPUArray,
- start_position: int,
- context_len: int,
- ) -> GPUArray:
- """Forward pass for batch decode using fixed-length KV cache.
-
- Processes multiple tokens at once for speculative decoding verification.
- """
- assert self._k_cache is not None, "Call init_fixed_cache first"
- seq_len = x.shape[0]
-
- if seq_len == 1:
- return self.forward_fixed_cache(x, start_position, context_len)
-
- if self.qkv_proj is not None:
- # Fused QKV projection (faster for non-FP8)
- qkv = self.qkv_proj(x)
- qkv_np = qkv.to_numpy()
- q_np = qkv_np[:, : self.q_dim]
- k_np = qkv_np[:, self.q_dim : self.q_dim + self.k_dim]
- v_np = qkv_np[:, self.q_dim + self.k_dim :]
-
- # Apply biases
- if self.q_proj.bias is not None:
- q_np = q_np + self.q_proj.bias.to_numpy()
- if self.k_proj.bias is not None:
- k_np = k_np + self.k_proj.bias.to_numpy()
- if self.v_proj.bias is not None:
- v_np = v_np + self.v_proj.bias.to_numpy()
-
- q_2d = from_numpy(q_np.astype(qkv_np.dtype))
- k_2d = from_numpy(k_np.astype(qkv_np.dtype))
- v_2d = from_numpy(v_np.astype(qkv_np.dtype))
- else:
- # Separate projections (for FP8)
- q_2d = self.q_proj(x)
- k_2d = self.k_proj(x)
- v_2d = self.v_proj(x)
-
- q = reshape_copy(q_2d, (seq_len, self.num_heads, self.head_dim))
- k = reshape_copy(k_2d, (seq_len, self.num_kv_heads, self.head_dim))
- v = reshape_copy(v_2d, (seq_len, self.num_kv_heads, self.head_dim))
-
- # QK Norm
- if self.q_norm is not None:
- q_flat = reshape_copy(q, (seq_len * self.num_heads, self.head_dim))
- q_normed = self.q_norm(q_flat)
- q = reshape_copy(q_normed, (seq_len, self.num_heads, self.head_dim))
- if self.k_norm is not None:
- k_flat = reshape_copy(k, (seq_len * self.num_kv_heads, self.head_dim))
- k_normed = self.k_norm(k_flat)
- k = reshape_copy(k_normed, (seq_len, self.num_kv_heads, self.head_dim))
-
- q_dtype = q.dtype
-
- # RoPE
- if self.config.use_rope and self._cos is not None and self._sin is not None:
- from pygpukit.ops.basic import rope_inplace_f32table
-
- end_pos = start_position + seq_len
- cos_f32 = from_numpy(self._cos[start_position:end_pos].astype(np.float32))
- sin_f32 = from_numpy(self._sin[start_position:end_pos].astype(np.float32))
- if q_dtype in (dt_float16, dt_bfloat16):
- rope_inplace_f32table(q, k, cos_f32, sin_f32)
- else:
- rope_inplace(q, k, cos_f32, sin_f32)
-
- # Update KV cache
- kv_cache_prefill_gqa(k, self._k_cache, self.num_heads, start_position)
- kv_cache_prefill_gqa(v, self._v_cache, self.num_heads, start_position)
-
- q_t = transpose_3d_021(q)
- # Allocate attn_out with matching dtype
- if q_dtype == dt_float16:
- out_np_dtype = np.float16
- elif q_dtype == dt_bfloat16:
- out_np_dtype = np.uint16 # bfloat16 stored as uint16
- else:
- out_np_dtype = np.float32
- attn_out = from_numpy(
- np.zeros((self.num_heads, seq_len, self.head_dim), dtype=out_np_dtype)
- )
-
- sdpa_causal_fixed_cache(q_t, self._k_cache, self._v_cache, attn_out, context_len)
-
- attn_output = transpose_3d_021(attn_out)
- attn_output = reshape_copy(attn_output, (seq_len, self.num_heads * self.head_dim))
- return self.o_proj(attn_output)
-
- def forward_fixed_cache_batch_zero_alloc(
- self,
- x: GPUArray,
- start_position: int,
- context_len: int,
- buffers: DecodeBuffers,
- rope_cos_gpu: GPUArray | None,
- rope_sin_gpu: GPUArray | None,
- start_pos_buf: GPUArray,
- ) -> GPUArray:
- """Zero-allocation forward pass for batch decode using fixed-length KV cache.
-
- This version uses pre-allocated buffers for all operations, making it
- compatible with CUDA Graph capture. No memory allocations occur.
- """
- assert self._k_cache is not None, "Call init_fixed_cache first"
- seq_len = x.shape[0]
-
- q_out = buffers.q_batch.view((seq_len, self.num_heads, self.head_dim))
- k_out = buffers.k_batch.view((seq_len, self.num_kv_heads, self.head_dim))
- v_out = buffers.v_batch.view((seq_len, self.num_kv_heads, self.head_dim))
-
- if self.qkv_proj is not None:
- # Fused QKV projection into pre-allocated buffer
- qkv_out = buffers.qkv_proj_out_batch.slice_rows(seq_len)
- self.qkv_proj(x, out=qkv_out)
-
- # Split QKV
- split_qkv_batch(qkv_out, q_out, k_out, v_out, self.q_dim, self.k_dim, self.v_dim)
-
- # Apply biases
- if self.q_proj.bias is not None:
- q_out_2d = q_out.view((seq_len, self.q_dim))
- bias_add_inplace(q_out_2d, self.q_proj.bias)
- if self.k_proj.bias is not None:
- k_out_2d = k_out.view((seq_len, self.k_dim))
- bias_add_inplace(k_out_2d, self.k_proj.bias)
- if self.v_proj.bias is not None:
- v_out_2d = v_out.view((seq_len, self.v_dim))
- bias_add_inplace(v_out_2d, self.v_proj.bias)
- else:
- # Separate projections (for FP8 - allocates, not zero-alloc)
- q_2d = self.q_proj(x)
- k_2d = self.k_proj(x)
- v_2d = self.v_proj(x)
- copy_to(reshape_copy(q_2d, (seq_len, self.num_heads, self.head_dim)), q_out)
- copy_to(reshape_copy(k_2d, (seq_len, self.num_kv_heads, self.head_dim)), k_out)
- copy_to(reshape_copy(v_2d, (seq_len, self.num_kv_heads, self.head_dim)), v_out)
-
- # QK Norm
- if self.q_norm is not None and buffers.q_flat_batch is not None:
- q_flat = buffers.q_flat_batch.slice_rows(seq_len * self.num_heads)
- copy_to(q_out.view((seq_len * self.num_heads, self.head_dim)), q_flat)
- rmsnorm(q_flat, self.q_norm.weight, self.q_norm.eps, out=q_flat)
- copy_to(q_flat.view((seq_len, self.num_heads, self.head_dim)), q_out)
-
- if self.k_norm is not None and buffers.k_flat_batch is not None:
- k_flat = buffers.k_flat_batch.slice_rows(seq_len * self.num_kv_heads)
- copy_to(k_out.view((seq_len * self.num_kv_heads, self.head_dim)), k_flat)
- rmsnorm(k_flat, self.k_norm.weight, self.k_norm.eps, out=k_flat)
- copy_to(k_flat.view((seq_len, self.num_kv_heads, self.head_dim)), k_out)
-
- # RoPE
- if self.config.use_rope and rope_cos_gpu is not None and rope_sin_gpu is not None:
- cos_out = buffers.cos_batch.slice_rows(seq_len)
- sin_out = buffers.sin_batch.slice_rows(seq_len)
- slice_rows_range_ptr(rope_cos_gpu, cos_out, start_pos_buf, seq_len)
- slice_rows_range_ptr(rope_sin_gpu, sin_out, start_pos_buf, seq_len)
- rope_inplace(q_out, k_out, cos_out, sin_out)
-
- # Update KV cache
- kv_cache_prefill_gqa(k_out, self._k_cache, self.num_heads, start_position)
- kv_cache_prefill_gqa(v_out, self._v_cache, self.num_heads, start_position)
-
- # Transpose Q for SDPA
- q_t_out = buffers.q_t_batch.view((self.num_heads, seq_len, self.head_dim))
- transpose_3d_021(q_out, out=q_t_out)
-
- # SDPA
- attn_out = buffers.attn_out_batch.view((self.num_heads, seq_len, self.head_dim))
- sdpa_causal_fixed_cache(q_t_out, self._k_cache, self._v_cache, attn_out, context_len)
-
- # Transpose output
- attn_out_t = buffers.attn_out_t_batch.view((seq_len, self.num_heads, self.head_dim))
- transpose_3d_021(attn_out, out=attn_out_t)
-
- attn_out_2d = attn_out_t.view((seq_len, self.num_heads * self.head_dim))
-
- # O projection
- o_out = buffers.o_proj_out_batch.slice_rows(seq_len)
- self.o_proj(attn_out_2d, out=o_out)
-
- return o_out
-
-
-# =============================================================================
-# Unified MLP
-# =============================================================================
-
-
-class MLP:
- """Unified MLP supporting GELU and SwiGLU activations.
-
- GELU (GPT-2 style):
- fc1 -> GELU -> fc2
-
- SwiGLU (LLaMA style):
- gate_proj -> SiLU -> * up_proj -> down_proj
-
- Supports FP8 quantized weights via LinearFP8.
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- # GELU path weights (GPUArray or LinearBF16/LinearFP8)
- fc1_weight: GPUArray | LinearBF16 | LinearFP8 | None = None,
- fc1_bias: GPUArray | None = None,
- fc2_weight: GPUArray | LinearBF16 | LinearFP8 | None = None,
- fc2_bias: GPUArray | None = None,
- # SwiGLU path weights (GPUArray or LinearBF16/LinearFP8)
- gate_proj: GPUArray | LinearBF16 | LinearFP8 | None = None,
- up_proj: GPUArray | LinearBF16 | LinearFP8 | None = None,
- down_proj: GPUArray | LinearBF16 | LinearFP8 | None = None,
- ):
- self.config = config
- self.activation = config.activation
-
- # Helper to wrap GPUArray in LinearBF16, or use pre-built LinearBF16/LinearFP8
- def wrap_linear(
- proj: GPUArray | LinearBF16 | LinearFP8 | None, bias: GPUArray | None = None
- ) -> LinearBF16 | LinearFP8 | None:
- if proj is None:
- return None
- if isinstance(proj, (LinearBF16, LinearFP8)):
- return proj
- return LinearBF16(proj, bias)
-
- if config.activation == "gelu":
- if fc1_weight is None or fc2_weight is None:
- raise ValueError("GELU MLP requires fc1_weight and fc2_weight")
- self.fc1 = wrap_linear(fc1_weight, fc1_bias)
- self.fc2 = wrap_linear(fc2_weight, fc2_bias)
- else: # silu (SwiGLU)
- if gate_proj is None or up_proj is None or down_proj is None:
- raise ValueError("SwiGLU MLP requires gate_proj, up_proj, down_proj")
-
- self.gate_proj = wrap_linear(gate_proj)
- self.up_proj = wrap_linear(up_proj)
- self.down_proj = wrap_linear(down_proj)
-
- # Get intermediate size from the projection
- if isinstance(gate_proj, (LinearBF16, LinearFP8)):
- self.intermediate_size = gate_proj.out_features
- else:
- self.intermediate_size = gate_proj.shape[0]
-
- # Fused gate_up projection only for non-FP8 (GPUArray) weights
- # FP8 weights can't be concatenated trivially
- if isinstance(gate_proj, GPUArray) and isinstance(up_proj, GPUArray):
- gate_up_weight = concat_axis0(gate_proj, up_proj)
- self.gate_up_proj: LinearBF16 | None = LinearBF16(gate_up_weight, None)
- else:
- self.gate_up_proj = None
-
- def __call__(self, x: GPUArray) -> GPUArray:
- if self.activation == "gelu":
- h = self.fc1(x)
- h = gelu(h)
- return self.fc2(h)
- else:
- gate = silu(self.gate_proj(x))
- up = self.up_proj(x)
- return self.down_proj(mul(gate, up))
-
-
-# =============================================================================
-# Mixture of Experts Layer
-# =============================================================================
-
-
-class MoELayer:
- """Mixture of Experts layer for Mixtral-style models.
-
- Architecture:
- 1. Router: hidden -> [num_experts] logits
- 2. Top-K selection with softmax
- 3. Expert FFN (SwiGLU) for each selected expert
- 4. Weighted combination of expert outputs
-
- Supports FP8 quantized expert weights via LinearFP8.
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- gate_weight: GPUArray, # [num_experts, hidden_size] - router
- expert_weights: list, # [(gate, up, down), ...] - GPUArray or LinearBF16/LinearFP8
- ):
- self.config = config
- self.num_experts = config.num_experts or len(expert_weights)
- self.num_experts_per_tok = config.num_experts_per_tok
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.moe_intermediate_size or config.intermediate_size
-
- # Router (gate) projection
- self.gate = LinearBF16(gate_weight)
-
- # Expert FFNs
- self.experts: list[MLP] = []
- for gate_proj, up_proj, down_proj in expert_weights:
- expert = MLP(
- config,
- gate_proj=gate_proj,
- up_proj=up_proj,
- down_proj=down_proj,
- )
- self.experts.append(expert)
-
- # Check if all experts use FP8 weights for grouped GEMM optimization
- self._use_grouped_gemm = False
- self._stacked_gate_weight: GPUArray | None = None
- self._stacked_gate_scale: GPUArray | None = None
- self._stacked_up_weight: GPUArray | None = None
- self._stacked_up_scale: GPUArray | None = None
- self._stacked_down_weight: GPUArray | None = None
- self._stacked_down_scale: GPUArray | None = None
-
- # Check if first expert uses FP8 - use grouped GEMM v2 for optimization
- # TEMP: Disabled for debugging
- import os
-
- if os.environ.get("PYGPUKIT_DISABLE_GROUPED_GEMM") != "1":
- if len(self.experts) > 0 and isinstance(self.experts[0].gate_proj, LinearFP8):
- self._stack_fp8_weights()
-
- # Profiling flag (set to True to enable timing)
- _profile: bool = True
- _profile_count: int = 0
-
- def _stack_fp8_weights(self) -> None:
- """Stack FP8 expert weights for grouped GEMM optimization."""
- # Collect weights from all experts
- gate_weights = []
- gate_scales = []
- up_weights = []
- up_scales = []
- down_weights = []
- down_scales = []
-
- for expert in self.experts:
- if not isinstance(expert.gate_proj, LinearFP8):
- return # Not all experts are FP8, abort
-
- gate_weights.append(expert.gate_proj.weight_fp8)
- gate_scales.append(expert.gate_proj.scale_inv)
- up_weights.append(expert.up_proj.weight_fp8)
- up_scales.append(expert.up_proj.scale_inv)
- down_weights.append(expert.down_proj.weight_fp8)
- down_scales.append(expert.down_proj.scale_inv)
-
- # Stack weights: [num_experts, N, K]
- # gate_proj: [intermediate_size, hidden_size] -> stacked [num_experts, intermediate_size, hidden_size]
- # Each weight is [N, K], stack along new axis 0
-
- def stack_arrays_fast(arrays: list[GPUArray]) -> GPUArray:
- """Stack arrays along new axis 0 using single allocation + cudaMemcpy."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- # Get shape info from first array
- first = arrays[0]
- num_arrays = len(arrays)
- inner_shape = first.shape # [N, K] or [N/128, K/128]
-
- # Calculate strides (nbytes is property, not method)
- bytes_per_array = first._get_native().nbytes
-
- # Allocate output: [num_arrays, *inner_shape]
- out_shape = [num_arrays] + list(inner_shape)
- out_native = native.empty(out_shape, first._get_native().dtype)
- out = GPUArray._wrap_native(out_native)
-
- # Copy each array to its slice using cuMemcpy
- for i, arr in enumerate(arrays):
- offset_bytes = i * bytes_per_array
- native.memcpy_device_to_device_offset(
- arr._get_native(),
- out._get_native(),
- 0, # src offset
- offset_bytes, # dst offset
- bytes_per_array,
- )
-
- return out
-
- self._stacked_gate_weight = stack_arrays_fast(gate_weights)
- self._stacked_gate_scale = stack_arrays_fast(gate_scales)
- self._stacked_up_weight = stack_arrays_fast(up_weights)
- self._stacked_up_scale = stack_arrays_fast(up_scales)
- self._stacked_down_weight = stack_arrays_fast(down_weights)
- self._stacked_down_scale = stack_arrays_fast(down_scales)
-
- self._use_grouped_gemm = True
- print(f"[MoE] Stacked {self.num_experts} expert weights for grouped GEMM")
-
- def __call__(self, x: GPUArray) -> GPUArray:
- """Forward pass through MoE layer.
-
- Args:
- x: Input tensor [batch, seq, hidden_size] or [seq, hidden_size]
-
- Returns:
- Output tensor with same shape as input
- """
- import time
-
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- profile = self._profile and MoELayer._profile_count < 3
- if profile:
- native.device_synchronize()
- t0 = time.perf_counter()
-
- original_shape = x.shape
- # Flatten to [num_tokens, hidden_size]
- if len(original_shape) == 3:
- batch, seq, hidden = original_shape
- num_tokens = batch * seq
- x = x.reshape(num_tokens, hidden)
- else:
- num_tokens, hidden = original_shape
-
- k = self.num_experts_per_tok
-
- # Step 1: Compute router logits
- router_logits = self.gate(x) # [num_tokens, num_experts]
- if profile:
- native.device_synchronize()
- t1 = time.perf_counter()
-
- # Step 2: Top-K selection
- router_weights = zeros((num_tokens, k), dtype=x.dtype)
- expert_indices = zeros((num_tokens, k), dtype="int32")
- native.moe_topk_with_indices(
- router_logits._get_native(),
- router_weights._get_native(),
- expert_indices._get_native(),
- k,
- )
-
- # Step 3: Softmax over selected experts
- native.moe_softmax_topk(router_weights._get_native(), k)
-
- # Step 4: Compute permutation for efficient expert dispatch
- expert_counts = zeros((self.num_experts,), dtype="int32")
- expert_offsets = zeros((self.num_experts + 1,), dtype="int32")
- permute_indices = zeros((num_tokens * k,), dtype="int32")
- reverse_perm = zeros((num_tokens * k,), dtype="int32")
- native.moe_compute_permutation(
- expert_indices._get_native(),
- expert_counts._get_native(),
- expert_offsets._get_native(),
- permute_indices._get_native(),
- reverse_perm._get_native(),
- self.num_experts,
- k,
- )
-
- # Step 5: Gather hidden states for experts
- gathered = zeros((num_tokens * k, hidden), dtype=x.dtype)
- native.moe_gather(
- x._get_native(),
- permute_indices._get_native(),
- gathered._get_native(),
- k,
- )
- if profile:
- native.device_synchronize()
- t2 = time.perf_counter()
-
- # Step 6: Run experts
- if self._use_grouped_gemm:
- # Use grouped GEMM for all experts in single kernel launches
- from pygpukit.ops.matmul import grouped_gemm_fp8_bf16
-
- # Create row_expert_ids from expert_offsets
- M_total = num_tokens * k
- row_expert_ids = zeros((M_total,), dtype="int32")
- native.moe_expand_expert_offsets(
- expert_offsets._get_native(),
- row_expert_ids._get_native(),
- self.num_experts,
- )
-
- # gate_proj: gathered[M_total, hidden] @ gate_weight[experts, inter, hidden]^T
- gate_out = grouped_gemm_fp8_bf16(
- gathered,
- self._stacked_gate_weight,
- self._stacked_gate_scale,
- row_expert_ids,
- )
-
- # up_proj: gathered[M_total, hidden] @ up_weight[experts, inter, hidden]^T
- up_out = grouped_gemm_fp8_bf16(
- gathered,
- self._stacked_up_weight,
- self._stacked_up_scale,
- row_expert_ids,
- )
-
- # SiLU(gate) * up
- intermediate = mul(silu(gate_out), up_out)
-
- # down_proj: intermediate[M_total, inter] @ down_weight[experts, hidden, inter]^T
- expert_outputs = grouped_gemm_fp8_bf16(
- intermediate,
- self._stacked_down_weight,
- self._stacked_down_scale,
- row_expert_ids,
- )
- else:
- # Fallback: Run experts sequentially
- # Get expert counts on CPU for loop
- expert_counts_cpu = expert_counts.to_numpy()
- expert_offsets_cpu = expert_offsets.to_numpy()
-
- # Build list of (expert_id, start, count) for non-empty experts
- expert_tasks = []
- for e in range(self.num_experts):
- start = int(expert_offsets_cpu[e])
- count = int(expert_counts_cpu[e])
- if count > 0:
- expert_tasks.append((e, start, count))
-
- def run_expert(task: tuple) -> GPUArray:
- e, start, count = task
- expert_input = gathered[start : start + count]
- return self.experts[e](expert_input)
-
- # Run experts sequentially
- expert_output_list = [run_expert(task) for task in expert_tasks]
-
- # Concatenate all expert outputs on GPU
- from functools import reduce
-
- expert_outputs = reduce(concat_axis0, expert_output_list)
-
- if profile:
- native.device_synchronize()
- t3 = time.perf_counter()
-
- # Step 7: Scatter and combine outputs
- output = zeros((num_tokens, hidden), dtype=x.dtype)
- native.moe_scatter(
- expert_outputs._get_native(),
- router_weights._get_native(),
- reverse_perm._get_native(),
- output._get_native(),
- k,
- )
- if profile:
- native.device_synchronize()
- t4 = time.perf_counter()
- MoELayer._profile_count += 1
- print(
- f"[MoE Profile] router={t1 - t0:.3f}s, routing={t2 - t1:.3f}s, experts={t3 - t2:.3f}s, scatter={t4 - t3:.3f}s"
- )
-
- # Reshape back
- if len(original_shape) == 3:
- output = output.reshape(*original_shape)
-
- return output
-
- def forward_zero_alloc(
- self,
- x: GPUArray,
- router_logits: GPUArray,
- router_weights: GPUArray,
- expert_indices: GPUArray,
- expert_counts: GPUArray,
- expert_offsets: GPUArray,
- permute_indices: GPUArray,
- reverse_perm: GPUArray,
- row_expert_ids: GPUArray,
- gathered: GPUArray,
- gate_out: GPUArray,
- up_out: GPUArray,
- intermediate: GPUArray,
- expert_outputs: GPUArray,
- output: GPUArray,
- ) -> GPUArray:
- """Zero-allocation forward pass for CUDA Graph support.
-
- This method uses pre-allocated buffers from DecodeBuffers to avoid
- any memory allocations during forward pass, enabling CUDA Graph capture.
-
- Args:
- x: Input tensor [1, hidden_size]
- router_logits: Pre-allocated [1, num_experts]
- router_weights: Pre-allocated [1, k]
- expert_indices: Pre-allocated [1, k] int32
- expert_counts: Pre-allocated [num_experts] int32
- expert_offsets: Pre-allocated [num_experts + 1] int32
- permute_indices: Pre-allocated [k] int32
- reverse_perm: Pre-allocated [k] int32
- row_expert_ids: Pre-allocated [k] int32
- gathered: Pre-allocated [k, hidden_size]
- gate_out: Pre-allocated [k, moe_intermediate_size]
- up_out: Pre-allocated [k, moe_intermediate_size]
- intermediate: Pre-allocated [k, moe_intermediate_size]
- expert_outputs: Pre-allocated [k, hidden_size]
- output: Pre-allocated [1, hidden_size]
-
- Returns:
- The output tensor (same as output parameter)
- """
- from pygpukit.core.backend import get_native_module
- from pygpukit.ops.elementwise import mul
- from pygpukit.ops.matmul import grouped_gemm_fp8_bf16
- from pygpukit.ops.nn import silu
-
- native = get_native_module()
-
- k = self.num_experts_per_tok
-
- # Step 1: Router forward (gate projection)
- self.gate(x, out=router_logits)
-
- # Step 2: Top-K selection (writes to router_weights and expert_indices)
- native.moe_topk_with_indices(
- router_logits._get_native(),
- router_weights._get_native(),
- expert_indices._get_native(),
- k,
- )
-
- # Step 3: Softmax over selected experts (in-place)
- native.moe_softmax_topk(router_weights._get_native(), k)
-
- # Step 4: Compute permutation
- native.moe_compute_permutation(
- expert_indices._get_native(),
- expert_counts._get_native(),
- expert_offsets._get_native(),
- permute_indices._get_native(),
- reverse_perm._get_native(),
- self.num_experts,
- k,
- )
-
- # Step 5: Gather hidden states
- native.moe_gather(
- x._get_native(),
- permute_indices._get_native(),
- gathered._get_native(),
- k,
- )
-
- # Step 6: Create row_expert_ids for grouped GEMM
- native.moe_expand_expert_offsets(
- expert_offsets._get_native(),
- row_expert_ids._get_native(),
- self.num_experts,
- )
-
- # Step 7: Expert computation with grouped GEMM
- # gate_proj: gathered[k, hidden] @ gate_weight[experts, inter, hidden]^T
- grouped_gemm_fp8_bf16(
- gathered,
- self._stacked_gate_weight,
- self._stacked_gate_scale,
- row_expert_ids,
- out=gate_out,
- )
-
- # up_proj: gathered[k, hidden] @ up_weight[experts, inter, hidden]^T
- grouped_gemm_fp8_bf16(
- gathered,
- self._stacked_up_weight,
- self._stacked_up_scale,
- row_expert_ids,
- out=up_out,
- )
-
- # SiLU(gate) * up -> intermediate
- silu(gate_out, out=intermediate)
- mul(intermediate, up_out, out=intermediate)
-
- # down_proj: intermediate[k, inter] @ down_weight[experts, hidden, inter]^T
- grouped_gemm_fp8_bf16(
- intermediate,
- self._stacked_down_weight,
- self._stacked_down_scale,
- row_expert_ids,
- out=expert_outputs,
- )
-
- # Step 8: Scatter and combine outputs
- native.moe_scatter(
- expert_outputs._get_native(),
- router_weights._get_native(),
- reverse_perm._get_native(),
- output._get_native(),
- k,
- )
-
- return output
-
-
-# =============================================================================
-# Unified TransformerBlock
-# =============================================================================
-
-
-class TransformerBlock:
- """Unified transformer block.
-
- Structure:
- Norm -> Attention -> Residual
- Norm -> MLP/MoE -> Residual
- """
-
- def __init__(
- self,
- attn_norm: Norm,
- attn: Attention,
- mlp_norm: Norm,
- mlp: MLP | MoELayer,
- ):
- self.attn_norm = attn_norm
- self.attn = attn
- self.mlp_norm = mlp_norm
- self.mlp = mlp # Can be MLP or MoELayer
-
- def __call__(
- self,
- x: GPUArray,
- position_ids: list[int] | None = None,
- past_kv: tuple | None = None,
- use_cache: bool = False,
- ) -> tuple[GPUArray, tuple | None]:
- # Attention block
- residual = x
- x = self.attn_norm(x)
- attn_out, present_kv = self.attn(x, position_ids, past_kv, use_cache)
- x = add(residual, attn_out)
-
- # MLP block
- residual = x
- x = self.mlp_norm(x)
- x = self.mlp(x)
- x = add(residual, x)
-
- return x, present_kv
diff --git a/src/pygpukit/llm/layers/__init__.py b/src/pygpukit/llm/layers/__init__.py
new file mode 100644
index 0000000..91b2c7a
--- /dev/null
+++ b/src/pygpukit/llm/layers/__init__.py
@@ -0,0 +1,74 @@
+"""Neural network layer implementations for PyGPUkit LLM.
+
+Provides:
+- LinearBF16: Dense layer with BF16 weights
+- LinearFP8: Dense layer with FP8 weights (online dequantization)
+- Norm: RMSNorm and LayerNorm
+- Attention: Multi-head attention with RoPE, GQA, QK-Norm, KV cache
+- MLP: Feed-forward network (GELU/SwiGLU)
+- MoELayer: Mixture of Experts
+- TransformerBlock: Attention + MLP with residual connections
+- RoPE utilities: precompute_freqs_cis, apply_rotary_pos_emb_numpy
+- Repack utilities: repack_weight, repack_linear, repack_norm
+"""
+
+from __future__ import annotations
+
+# Attention
+from .attention import Attention
+
+# TransformerBlock
+from .block import TransformerBlock
+
+# Linear layers
+from .linear import (
+ Linear,
+ LinearBF16,
+ LinearFP8,
+)
+
+# MLP
+from .mlp import MLP
+
+# MoE
+from .moe import MoELayer
+
+# Normalization
+from .norm import Norm
+
+# RoPE utilities
+from .rope import (
+ apply_rotary_pos_emb_numpy,
+ precompute_freqs_cis,
+)
+
+# Repack utilities
+from .utils import (
+ repack_linear,
+ repack_norm,
+ repack_weight,
+)
+
+__all__ = [
+ # Linear layers
+ "LinearBF16",
+ "LinearFP8",
+ "Linear",
+ # Normalization
+ "Norm",
+ # RoPE
+ "precompute_freqs_cis",
+ "apply_rotary_pos_emb_numpy",
+ # Attention
+ "Attention",
+ # MLP
+ "MLP",
+ # MoE
+ "MoELayer",
+ # TransformerBlock
+ "TransformerBlock",
+ # Repack utilities
+ "repack_weight",
+ "repack_linear",
+ "repack_norm",
+]
diff --git a/src/pygpukit/llm/layers/attention.py b/src/pygpukit/llm/layers/attention.py
new file mode 100644
index 0000000..40cf9e2
--- /dev/null
+++ b/src/pygpukit/llm/layers/attention.py
@@ -0,0 +1,560 @@
+"""Attention layer implementation for PyGPUkit LLM.
+
+Provides:
+- Attention: Multi-head attention with RoPE, GQA, QK-Norm, KV cache
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.dtypes import bfloat16 as dt_bfloat16
+from pygpukit.core.dtypes import float16 as dt_float16
+from pygpukit.core.factory import from_numpy
+from pygpukit.ops.basic import (
+ bias_add_inplace,
+ concat_axis0,
+ copy_to,
+ kv_cache_prefill_gqa,
+ kv_cache_update_gqa,
+ repeat_interleave_axis1,
+ reshape_copy,
+ rmsnorm,
+ rope_inplace,
+ sdpa_causal,
+ sdpa_causal_fixed_cache,
+ slice_rows_range_ptr,
+ split_qkv_batch,
+ transpose_3d_021,
+)
+
+from .linear import LinearBF16, LinearFP8
+from .norm import Norm
+from .rope import precompute_freqs_cis
+
+if TYPE_CHECKING:
+ from pygpukit.llm.buffers import DecodeBuffers
+ from pygpukit.llm.config import TransformerConfig
+
+
+class Attention:
+ """Unified attention with Hybrid CPU/GPU execution.
+
+ Supports:
+ - Multi-Head Attention (MHA): num_kv_heads == num_heads
+ - Grouped Query Attention (GQA): num_kv_heads < num_heads
+ - RoPE: enabled via config.use_rope
+ - QK Norm: optional normalization of Q and K (Qwen3 style)
+ - Hybrid execution: CPU for seq_len=1, GPU for longer sequences
+ - FP8 quantized weights via LinearFP8
+ """
+
+ def __init__(
+ self,
+ q_proj: GPUArray | LinearBF16 | LinearFP8,
+ k_proj: GPUArray | LinearBF16 | LinearFP8,
+ v_proj: GPUArray | LinearBF16 | LinearFP8,
+ o_proj: GPUArray | LinearBF16 | LinearFP8,
+ config: TransformerConfig,
+ q_bias: GPUArray | None = None,
+ k_bias: GPUArray | None = None,
+ v_bias: GPUArray | None = None,
+ o_bias: GPUArray | None = None,
+ q_norm: Norm | None = None,
+ k_norm: Norm | None = None,
+ ):
+ # Accept either GPUArray (wrapped in LinearBF16) or pre-built LinearBF16/LinearFP8
+ def wrap_linear(
+ proj: GPUArray | LinearBF16 | LinearFP8, bias: GPUArray | None
+ ) -> LinearBF16 | LinearFP8:
+ if isinstance(proj, (LinearBF16, LinearFP8)):
+ return proj
+ return LinearBF16(proj, bias)
+
+ self.q_proj = wrap_linear(q_proj, q_bias)
+ self.k_proj = wrap_linear(k_proj, k_bias)
+ self.v_proj = wrap_linear(v_proj, v_bias)
+ self.o_proj = wrap_linear(o_proj, o_bias)
+
+ # QK Norm (Qwen3 style)
+ self.q_norm = q_norm
+ self.k_norm = k_norm
+
+ self.config = config
+ self.head_dim = config.head_dim
+ self.num_heads = config.num_heads
+ assert config.num_kv_heads is not None # Set in __post_init__
+ self.num_kv_heads: int = config.num_kv_heads
+ self.num_kv_groups = config.num_kv_groups
+
+ # Store dimensions for QKV split
+ self.q_dim = self.num_heads * self.head_dim
+ self.k_dim = self.num_kv_heads * self.head_dim
+ self.v_dim = self.num_kv_heads * self.head_dim
+
+ # Create fused QKV projection (reduces 3 matmuls to 1)
+ # Skip fusion for FP8 (LinearFP8 can't be concatenated)
+ self.qkv_proj: LinearBF16 | None = None
+ if not isinstance(self.q_proj, LinearFP8):
+ # Extract weights from LinearBF16 for concatenation
+ q_weight = self.q_proj.weight if isinstance(self.q_proj, LinearBF16) else q_proj
+ k_weight = self.k_proj.weight if isinstance(self.k_proj, LinearBF16) else k_proj
+ v_weight = self.v_proj.weight if isinstance(self.v_proj, LinearBF16) else v_proj
+ qkv_weight = concat_axis0(concat_axis0(q_weight, k_weight), v_weight)
+ self.qkv_proj = LinearBF16(qkv_weight, None)
+
+ # Precompute RoPE if enabled
+ self._cos: np.ndarray | None
+ self._sin: np.ndarray | None
+ if config.use_rope:
+ self._cos, self._sin = precompute_freqs_cis(
+ self.head_dim, config.max_position_embeddings, config.rope_theta
+ )
+ else:
+ self._cos, self._sin = None, None
+
+ # Fixed-length KV cache for CUDA Graph (initialized on first use)
+ self._k_cache: GPUArray | None = None
+ self._v_cache: GPUArray | None = None
+ self._max_cache_len: int = 0
+
+ # Lookahead KV tracking for Jacobi decoding
+ self._confirmed_pos: int = 0
+ self._logical_pos: int = 0
+
+ def init_fixed_cache(self, max_seq_len: int, dtype: str = "float16") -> None:
+ """Initialize fixed-length KV cache for CUDA Graph capture.
+
+ Args:
+ max_seq_len: Maximum sequence length to support.
+ dtype: Data type for cache (float16/bfloat16/float32).
+ """
+ cache_shape = (self.num_heads, max_seq_len, self.head_dim)
+ if dtype == "float16":
+ np_dtype = np.float16
+ elif dtype == "bfloat16":
+ np_dtype = np.uint16 # bf16 stored as uint16
+ else:
+ np_dtype = np.float32
+ self._k_cache = from_numpy(np.zeros(cache_shape, dtype=np_dtype))
+ self._v_cache = from_numpy(np.zeros(cache_shape, dtype=np_dtype))
+ self._max_cache_len = max_seq_len
+ self._confirmed_pos = 0
+ self._logical_pos = 0
+
+ # =========================================================================
+ # Lookahead KV Cache Management (for Jacobi Decoding)
+ # =========================================================================
+
+ def set_confirmed_pos(self, pos: int) -> None:
+ """Set the confirmed position (e.g., after prefill)."""
+ assert 0 <= pos <= self._max_cache_len, f"Invalid pos {pos}"
+ self._confirmed_pos = pos
+ self._logical_pos = pos
+
+ def reset_lookahead(self) -> None:
+ """Reset lookahead pointer to confirmed position."""
+ self._logical_pos = self._confirmed_pos
+
+ def commit_lookahead(self, n_accepted: int) -> None:
+ """Commit accepted tokens by advancing confirmed_pos."""
+ new_pos = self._confirmed_pos + n_accepted
+ assert new_pos <= self._max_cache_len, f"Commit exceeds cache: {new_pos}"
+ self._confirmed_pos = new_pos
+ self._logical_pos = new_pos
+
+ def get_confirmed_pos(self) -> int:
+ """Get current confirmed position."""
+ return self._confirmed_pos
+
+ def __call__(
+ self,
+ x: GPUArray,
+ position_ids: list[int] | None = None,
+ past_kv: tuple | None = None,
+ use_cache: bool = False,
+ ) -> tuple[GPUArray, tuple | None]:
+ """Forward pass with hybrid CPU/GPU attention.
+
+ Args:
+ x: Input tensor [seq_len, hidden_size]
+ position_ids: Position IDs for RoPE (auto-generated if None)
+ past_kv: Tuple of (past_k, past_v) numpy arrays
+ use_cache: Whether to return KV cache
+
+ Returns:
+ Tuple of (output, present_kv)
+ """
+ seq_len = x.shape[0]
+
+ if position_ids is None:
+ position_ids = list(range(seq_len))
+
+ return self._forward_gpu(x, position_ids, past_kv, use_cache)
+
+ def _forward_gpu(
+ self,
+ x: GPUArray,
+ position_ids: list[int],
+ past_kv: tuple | None,
+ use_cache: bool,
+ ) -> tuple[GPUArray, tuple | None]:
+ """GPU path for long sequences (prefill)."""
+ seq_len = x.shape[0]
+
+ # Project Q, K, V
+ q = self.q_proj(x)
+ k = self.k_proj(x)
+ v = self.v_proj(x)
+
+ # Reshape for multi-head
+ q = reshape_copy(q, (seq_len, self.num_heads, self.head_dim))
+ k = reshape_copy(k, (seq_len, self.num_kv_heads, self.head_dim))
+ v = reshape_copy(v, (seq_len, self.num_kv_heads, self.head_dim))
+
+ # QK Norm (Qwen3 style)
+ if self.q_norm is not None:
+ q_shape = (seq_len, self.num_heads, self.head_dim)
+ q_2d = reshape_copy(q, (seq_len * self.num_heads, self.head_dim))
+ q_2d = self.q_norm(q_2d)
+ q = reshape_copy(q_2d, q_shape)
+ if self.k_norm is not None:
+ k_shape = (seq_len, self.num_kv_heads, self.head_dim)
+ k_2d = reshape_copy(k, (seq_len * self.num_kv_heads, self.head_dim))
+ k_2d = self.k_norm(k_2d)
+ k = reshape_copy(k_2d, k_shape)
+
+ # Apply RoPE on GPU
+ if self.config.use_rope:
+ assert self._cos is not None and self._sin is not None
+ from pygpukit.ops.basic import rope_inplace_f32table
+
+ q_dtype = q.dtype
+ cos_f32 = from_numpy(self._cos[position_ids].astype(np.float32))
+ sin_f32 = from_numpy(self._sin[position_ids].astype(np.float32))
+ if q_dtype in (dt_float16, dt_bfloat16):
+ # Use f32 tables directly for higher precision (no intermediate alloc)
+ rope_inplace_f32table(q, k, cos_f32, sin_f32)
+ else:
+ rope_inplace(q, k, cos_f32, sin_f32)
+
+ # GPU KV Cache
+ if past_kv is not None:
+ past_k, past_v = past_kv
+ if isinstance(past_k, GPUArray):
+ k = concat_axis0(past_k, k)
+ v = concat_axis0(past_v, v)
+ else:
+ k_np = k.to_numpy()
+ v_np = v.to_numpy()
+ k_np = np.concatenate([past_k, k_np], axis=0)
+ v_np = np.concatenate([past_v, v_np], axis=0)
+ k = from_numpy(k_np)
+ v = from_numpy(v_np)
+
+ present_kv = (k, v) if use_cache else None
+
+ # Expand for GQA on GPU
+ if self.num_kv_groups > 1:
+ k_expanded = repeat_interleave_axis1(k, self.num_kv_groups)
+ v_expanded = repeat_interleave_axis1(v, self.num_kv_groups)
+ else:
+ k_expanded = k
+ v_expanded = v
+
+ # GPU SDPA
+ q_t = transpose_3d_021(q)
+ k_t = transpose_3d_021(k_expanded)
+ v_t = transpose_3d_021(v_expanded)
+
+ attn_output = sdpa_causal(q_t, k_t, v_t)
+ attn_output = transpose_3d_021(attn_output)
+ attn_output = reshape_copy(attn_output, (seq_len, self.num_heads * self.head_dim))
+
+ return self.o_proj(attn_output), present_kv
+
+ def forward_fixed_cache(
+ self,
+ x: GPUArray,
+ position: int,
+ context_len: int,
+ *,
+ out: GPUArray | None = None,
+ ) -> GPUArray:
+ """Forward pass using fixed-length KV cache (for CUDA Graph decode).
+
+ Args:
+ x: Input tensor [1, hidden_size] - single token
+ position: Current position in sequence (for RoPE and cache update)
+ context_len: Total context length (prefill + decoded so far)
+ out: Optional pre-allocated output buffer
+
+ Returns:
+ Output tensor [1, hidden_size]
+ """
+ assert self._k_cache is not None, "Call init_fixed_cache first"
+ assert x.shape[0] == 1, "forward_fixed_cache expects single token"
+
+ if self.qkv_proj is not None:
+ # Fused QKV projection (faster for non-FP8)
+ qkv = self.qkv_proj(x)
+ q_2d = qkv.narrow(0, self.q_dim)
+ k_2d = qkv.narrow(self.q_dim, self.k_dim)
+ v_2d = qkv.narrow(self.q_dim + self.k_dim, self.v_dim)
+
+ # Apply biases separately
+ if self.q_proj.bias is not None:
+ bias_add_inplace(q_2d, self.q_proj.bias)
+ if self.k_proj.bias is not None:
+ bias_add_inplace(k_2d, self.k_proj.bias)
+ if self.v_proj.bias is not None:
+ bias_add_inplace(v_2d, self.v_proj.bias)
+ else:
+ # Separate projections (for FP8)
+ q_2d = self.q_proj(x)
+ k_2d = self.k_proj(x)
+ v_2d = self.v_proj(x)
+
+ # Zero-copy reshape
+ q = q_2d.view((1, self.num_heads, self.head_dim))
+ k = k_2d.view((1, self.num_kv_heads, self.head_dim))
+ v = v_2d.view((1, self.num_kv_heads, self.head_dim))
+
+ # QK Norm
+ if self.q_norm is not None:
+ q_flat = q.view((self.num_heads, self.head_dim))
+ q_normed = self.q_norm(q_flat)
+ q = q_normed.view((1, self.num_heads, self.head_dim))
+ if self.k_norm is not None:
+ k_flat = k.view((self.num_kv_heads, self.head_dim))
+ k_normed = self.k_norm(k_flat)
+ k = k_normed.view((1, self.num_kv_heads, self.head_dim))
+
+ q_dtype = q.dtype
+
+ # Apply RoPE
+ if self.config.use_rope and self._cos is not None and self._sin is not None:
+ from pygpukit.ops.basic import rope_inplace_f32table
+
+ cos_f32 = from_numpy(self._cos[position : position + 1].astype(np.float32))
+ sin_f32 = from_numpy(self._sin[position : position + 1].astype(np.float32))
+ if q_dtype in (dt_float16, dt_bfloat16):
+ rope_inplace_f32table(q, k, cos_f32, sin_f32)
+ else:
+ rope_inplace(q, k, cos_f32, sin_f32)
+
+ # Update KV cache
+ kv_cache_update_gqa(k, self._k_cache, self.num_heads, position)
+ kv_cache_update_gqa(v, self._v_cache, self.num_heads, position)
+
+ q_t = q.view((self.num_heads, 1, self.head_dim))
+
+ # Allocate output buffer if needed
+ if out is None:
+ if q_dtype == dt_float16:
+ out_np_dtype = np.float16
+ elif q_dtype == dt_bfloat16:
+ out_np_dtype = np.uint16
+ else:
+ out_np_dtype = np.float32
+ attn_out = from_numpy(np.zeros((self.num_heads, 1, self.head_dim), dtype=out_np_dtype))
+ else:
+ attn_out = out
+
+ sdpa_causal_fixed_cache(q_t, self._k_cache, self._v_cache, attn_out, context_len)
+
+ attn_output = attn_out.view((1, self.num_heads * self.head_dim))
+ return self.o_proj(attn_output)
+
+ def forward_fixed_cache_batch(
+ self,
+ x: GPUArray,
+ start_position: int,
+ context_len: int,
+ ) -> GPUArray:
+ """Forward pass for batch decode using fixed-length KV cache.
+
+ Processes multiple tokens at once for speculative decoding verification.
+ """
+ assert self._k_cache is not None, "Call init_fixed_cache first"
+ seq_len = x.shape[0]
+
+ if seq_len == 1:
+ return self.forward_fixed_cache(x, start_position, context_len)
+
+ if self.qkv_proj is not None:
+ # Fused QKV projection (faster for non-FP8)
+ qkv = self.qkv_proj(x)
+ qkv_np = qkv.to_numpy()
+ q_np = qkv_np[:, : self.q_dim]
+ k_np = qkv_np[:, self.q_dim : self.q_dim + self.k_dim]
+ v_np = qkv_np[:, self.q_dim + self.k_dim :]
+
+ # Apply biases
+ if self.q_proj.bias is not None:
+ q_np = q_np + self.q_proj.bias.to_numpy()
+ if self.k_proj.bias is not None:
+ k_np = k_np + self.k_proj.bias.to_numpy()
+ if self.v_proj.bias is not None:
+ v_np = v_np + self.v_proj.bias.to_numpy()
+
+ q_2d = from_numpy(q_np.astype(qkv_np.dtype))
+ k_2d = from_numpy(k_np.astype(qkv_np.dtype))
+ v_2d = from_numpy(v_np.astype(qkv_np.dtype))
+ else:
+ # Separate projections (for FP8)
+ q_2d = self.q_proj(x)
+ k_2d = self.k_proj(x)
+ v_2d = self.v_proj(x)
+
+ q = reshape_copy(q_2d, (seq_len, self.num_heads, self.head_dim))
+ k = reshape_copy(k_2d, (seq_len, self.num_kv_heads, self.head_dim))
+ v = reshape_copy(v_2d, (seq_len, self.num_kv_heads, self.head_dim))
+
+ # QK Norm
+ if self.q_norm is not None:
+ q_flat = reshape_copy(q, (seq_len * self.num_heads, self.head_dim))
+ q_normed = self.q_norm(q_flat)
+ q = reshape_copy(q_normed, (seq_len, self.num_heads, self.head_dim))
+ if self.k_norm is not None:
+ k_flat = reshape_copy(k, (seq_len * self.num_kv_heads, self.head_dim))
+ k_normed = self.k_norm(k_flat)
+ k = reshape_copy(k_normed, (seq_len, self.num_kv_heads, self.head_dim))
+
+ q_dtype = q.dtype
+
+ # RoPE
+ if self.config.use_rope and self._cos is not None and self._sin is not None:
+ from pygpukit.ops.basic import rope_inplace_f32table
+
+ end_pos = start_position + seq_len
+ cos_f32 = from_numpy(self._cos[start_position:end_pos].astype(np.float32))
+ sin_f32 = from_numpy(self._sin[start_position:end_pos].astype(np.float32))
+ if q_dtype in (dt_float16, dt_bfloat16):
+ rope_inplace_f32table(q, k, cos_f32, sin_f32)
+ else:
+ rope_inplace(q, k, cos_f32, sin_f32)
+
+ # Update KV cache
+ kv_cache_prefill_gqa(k, self._k_cache, self.num_heads, start_position)
+ kv_cache_prefill_gqa(v, self._v_cache, self.num_heads, start_position)
+
+ q_t = transpose_3d_021(q)
+ # Allocate attn_out with matching dtype
+ if q_dtype == dt_float16:
+ out_np_dtype = np.float16
+ elif q_dtype == dt_bfloat16:
+ out_np_dtype = np.uint16 # bfloat16 stored as uint16
+ else:
+ out_np_dtype = np.float32
+ attn_out = from_numpy(
+ np.zeros((self.num_heads, seq_len, self.head_dim), dtype=out_np_dtype)
+ )
+
+ sdpa_causal_fixed_cache(q_t, self._k_cache, self._v_cache, attn_out, context_len)
+
+ attn_output = transpose_3d_021(attn_out)
+ attn_output = reshape_copy(attn_output, (seq_len, self.num_heads * self.head_dim))
+ return self.o_proj(attn_output)
+
+ def forward_fixed_cache_batch_zero_alloc(
+ self,
+ x: GPUArray,
+ start_position: int,
+ context_len: int,
+ buffers: DecodeBuffers,
+ rope_cos_gpu: GPUArray | None,
+ rope_sin_gpu: GPUArray | None,
+ start_pos_buf: GPUArray,
+ ) -> GPUArray:
+ """Zero-allocation forward pass for batch decode using fixed-length KV cache.
+
+ This version uses pre-allocated buffers for all operations, making it
+ compatible with CUDA Graph capture. No memory allocations occur.
+ """
+ assert self._k_cache is not None, "Call init_fixed_cache first"
+ seq_len = x.shape[0]
+
+ q_out = buffers.q_batch.view((seq_len, self.num_heads, self.head_dim))
+ k_out = buffers.k_batch.view((seq_len, self.num_kv_heads, self.head_dim))
+ v_out = buffers.v_batch.view((seq_len, self.num_kv_heads, self.head_dim))
+
+ if self.qkv_proj is not None:
+ # Fused QKV projection into pre-allocated buffer
+ qkv_out = buffers.qkv_proj_out_batch.slice_rows(seq_len)
+ self.qkv_proj(x, out=qkv_out)
+
+ # Split QKV
+ split_qkv_batch(qkv_out, q_out, k_out, v_out, self.q_dim, self.k_dim, self.v_dim)
+
+ # Apply biases
+ if self.q_proj.bias is not None:
+ q_out_2d = q_out.view((seq_len, self.q_dim))
+ bias_add_inplace(q_out_2d, self.q_proj.bias)
+ if self.k_proj.bias is not None:
+ k_out_2d = k_out.view((seq_len, self.k_dim))
+ bias_add_inplace(k_out_2d, self.k_proj.bias)
+ if self.v_proj.bias is not None:
+ v_out_2d = v_out.view((seq_len, self.v_dim))
+ bias_add_inplace(v_out_2d, self.v_proj.bias)
+ else:
+ # Separate projections (for FP8 - allocates, not zero-alloc)
+ q_2d = self.q_proj(x)
+ k_2d = self.k_proj(x)
+ v_2d = self.v_proj(x)
+ copy_to(reshape_copy(q_2d, (seq_len, self.num_heads, self.head_dim)), q_out)
+ copy_to(reshape_copy(k_2d, (seq_len, self.num_kv_heads, self.head_dim)), k_out)
+ copy_to(reshape_copy(v_2d, (seq_len, self.num_kv_heads, self.head_dim)), v_out)
+
+ # QK Norm
+ if self.q_norm is not None and buffers.q_flat_batch is not None:
+ q_flat = buffers.q_flat_batch.slice_rows(seq_len * self.num_heads)
+ copy_to(q_out.view((seq_len * self.num_heads, self.head_dim)), q_flat)
+ rmsnorm(q_flat, self.q_norm.weight, self.q_norm.eps, out=q_flat)
+ copy_to(q_flat.view((seq_len, self.num_heads, self.head_dim)), q_out)
+
+ if self.k_norm is not None and buffers.k_flat_batch is not None:
+ k_flat = buffers.k_flat_batch.slice_rows(seq_len * self.num_kv_heads)
+ copy_to(k_out.view((seq_len * self.num_kv_heads, self.head_dim)), k_flat)
+ rmsnorm(k_flat, self.k_norm.weight, self.k_norm.eps, out=k_flat)
+ copy_to(k_flat.view((seq_len, self.num_kv_heads, self.head_dim)), k_out)
+
+ # RoPE
+ if self.config.use_rope and rope_cos_gpu is not None and rope_sin_gpu is not None:
+ cos_out = buffers.cos_batch.slice_rows(seq_len)
+ sin_out = buffers.sin_batch.slice_rows(seq_len)
+ slice_rows_range_ptr(rope_cos_gpu, cos_out, start_pos_buf, seq_len)
+ slice_rows_range_ptr(rope_sin_gpu, sin_out, start_pos_buf, seq_len)
+ rope_inplace(q_out, k_out, cos_out, sin_out)
+
+ # Update KV cache
+ kv_cache_prefill_gqa(k_out, self._k_cache, self.num_heads, start_position)
+ kv_cache_prefill_gqa(v_out, self._v_cache, self.num_heads, start_position)
+
+ # Transpose Q for SDPA
+ q_t_out = buffers.q_t_batch.view((self.num_heads, seq_len, self.head_dim))
+ transpose_3d_021(q_out, out=q_t_out)
+
+ # SDPA
+ attn_out = buffers.attn_out_batch.view((self.num_heads, seq_len, self.head_dim))
+ sdpa_causal_fixed_cache(q_t_out, self._k_cache, self._v_cache, attn_out, context_len)
+
+ # Transpose output
+ attn_out_t = buffers.attn_out_t_batch.view((seq_len, self.num_heads, self.head_dim))
+ transpose_3d_021(attn_out, out=attn_out_t)
+
+ attn_out_2d = attn_out_t.view((seq_len, self.num_heads * self.head_dim))
+
+ # O projection
+ o_out = buffers.o_proj_out_batch.slice_rows(seq_len)
+ self.o_proj(attn_out_2d, out=o_out)
+
+ return o_out
+
+
+__all__ = [
+ "Attention",
+]
diff --git a/src/pygpukit/llm/layers/block.py b/src/pygpukit/llm/layers/block.py
new file mode 100644
index 0000000..f507bdf
--- /dev/null
+++ b/src/pygpukit/llm/layers/block.py
@@ -0,0 +1,62 @@
+"""Transformer block implementation for PyGPUkit LLM.
+
+Provides:
+- TransformerBlock: Attention + MLP with residual connections
+"""
+
+from __future__ import annotations
+
+from pygpukit.core.array import GPUArray
+from pygpukit.ops.basic import add
+
+from .attention import Attention
+from .mlp import MLP
+from .moe import MoELayer
+from .norm import Norm
+
+
+class TransformerBlock:
+ """Unified transformer block.
+
+ Structure:
+ Norm -> Attention -> Residual
+ Norm -> MLP/MoE -> Residual
+ """
+
+ def __init__(
+ self,
+ attn_norm: Norm,
+ attn: Attention,
+ mlp_norm: Norm,
+ mlp: MLP | MoELayer,
+ ):
+ self.attn_norm = attn_norm
+ self.attn = attn
+ self.mlp_norm = mlp_norm
+ self.mlp = mlp # Can be MLP or MoELayer
+
+ def __call__(
+ self,
+ x: GPUArray,
+ position_ids: list[int] | None = None,
+ past_kv: tuple | None = None,
+ use_cache: bool = False,
+ ) -> tuple[GPUArray, tuple | None]:
+ # Attention block
+ residual = x
+ x = self.attn_norm(x)
+ attn_out, present_kv = self.attn(x, position_ids, past_kv, use_cache)
+ x = add(residual, attn_out)
+
+ # MLP block
+ residual = x
+ x = self.mlp_norm(x)
+ x = self.mlp(x)
+ x = add(residual, x)
+
+ return x, present_kv
+
+
+__all__ = [
+ "TransformerBlock",
+]
diff --git a/src/pygpukit/llm/layers/linear.py b/src/pygpukit/llm/layers/linear.py
new file mode 100644
index 0000000..a59ed65
--- /dev/null
+++ b/src/pygpukit/llm/layers/linear.py
@@ -0,0 +1,267 @@
+"""Linear layer implementations for PyGPUkit LLM.
+
+Provides:
+- LinearBF16: Dense layer with BF16 weights
+- LinearFP8: Dense layer with FP8 weights (online dequantization)
+"""
+
+from __future__ import annotations
+
+import numpy as np
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.dtypes import bfloat16 as dt_bfloat16
+from pygpukit.core.factory import from_numpy, zeros
+from pygpukit.ops.basic import (
+ bias_add_inplace,
+ gemv_bf16,
+ gemv_fp8_bf16,
+ matmul,
+ transpose,
+ w8a16_gemm_sm120,
+)
+
+
+class LinearBF16:
+ """BF16 Linear layer: y = xW^T + b
+
+ Weights are stored as [out_features, in_features] (PyTorch convention).
+
+ For M=1 (single token decode), uses custom GEMV kernel which is 4-6x faster
+ than cuBLASLt matmul. Automatically falls back to matmul for batch > 1.
+ """
+
+ # Class-level flag to enable/disable GEMV optimization
+ _use_gemv: bool = True
+
+ def __init__(self, weight: GPUArray, bias: GPUArray | None = None):
+ if weight.ndim != 2:
+ raise ValueError(f"weight must be 2D, got {weight.ndim}D")
+ self.weight = weight
+ self.bias = bias
+ self.out_features = weight.shape[0]
+ self.in_features = weight.shape[1]
+ self._weight_t: GPUArray | None = None
+
+ def __call__(self, x: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
+ """Forward pass: y = xW^T + b
+
+ Args:
+ x: Input tensor [batch, in_features]
+ out: Optional output buffer [batch, out_features]. If provided,
+ result is written in-place (for CUDA Graph capture).
+ """
+ if x.ndim != 2:
+ raise ValueError(f"input must be 2D [batch, in_features], got {x.ndim}D")
+ if x.shape[1] != self.in_features:
+ raise ValueError(f"input features {x.shape[1]} != weight {self.in_features}")
+
+ if self._weight_t is None:
+ self._weight_t = transpose(self.weight)
+
+ # Use GEMV for M=1 with BF16 (1.3-2.4x faster than matmul)
+ # Skip GEMV when out is provided (CUDA Graph mode) - GEMV allocates internally
+ use_gemv = (
+ LinearBF16._use_gemv
+ and x.shape[0] == 1
+ and x.dtype == dt_bfloat16
+ and out is None # GEMV allocates, not compatible with CUDA Graph
+ )
+
+ if use_gemv:
+ # GEMV path for M=1 decode
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ x_1d = x.view((self.in_features,))
+
+ # Use optimized kernel (SM80+) with B[N,K] layout
+ if native.gemv_bf16_opt_available():
+ y_1d = zeros((self.out_features,), dtype="bfloat16")
+ # gemv_bf16_opt: A[K] @ B[N,K]^T -> C[N]
+ native.gemv_bf16_opt_sm120(
+ x_1d._get_native(),
+ self.weight._get_native(), # [N, K] - no transpose
+ y_1d._get_native(),
+ )
+ else:
+ # Fallback: old kernel with B[K,N] layout
+ y_1d = gemv_bf16(x_1d, self._weight_t)
+
+ y = y_1d.view((1, self.out_features))
+ else:
+ # Standard matmul path
+ y = matmul(x, self._weight_t, out=out)
+
+ if self.bias is not None:
+ bias_add_inplace(y, self.bias)
+
+ return y
+
+
+# Backward compatibility alias
+Linear = LinearBF16
+
+
+class LinearFP8:
+ """FP8 Linear layer with online dequantization: y = x @ dequant(W)^T + b
+
+ Stores weights in FP8 E4M3 format with block-wise scaling factors.
+ Dequantizes on-the-fly during forward pass using CUDA kernel.
+
+ Memory savings: 50% vs BF16 (1 byte vs 2 bytes per weight + small scale overhead)
+
+ For M=1 (single token decode), uses FP8 GEMV kernel with online dequantization.
+ For larger batches, falls back to CPU dequantization + GPU matmul.
+ """
+
+ # Class-level flag to enable/disable GEMV optimization
+ _use_gemv: bool = True
+
+ # FP8 E4M3 to float32 lookup table (for CPU fallback)
+ _FP8_TABLE: np.ndarray | None = None
+
+ @classmethod
+ def _get_fp8_table(cls) -> np.ndarray:
+ """Build FP8 E4M3 to float32 conversion lookup table."""
+ if cls._FP8_TABLE is not None:
+ return cls._FP8_TABLE
+
+ table = np.zeros(256, dtype=np.float32)
+ for i in range(256):
+ sign = (i >> 7) & 1
+ exp = (i >> 3) & 0xF
+ mant = i & 0x7
+
+ if exp == 0xF and mant == 0x7:
+ table[i] = np.nan
+ elif exp == 0:
+ value = (mant / 8.0) * (2.0**-6)
+ table[i] = -value if sign else value
+ else:
+ value = (1.0 + mant / 8.0) * (2.0 ** (exp - 7))
+ table[i] = -value if sign else value
+
+ cls._FP8_TABLE = table
+ return table
+
+ def __init__(
+ self,
+ weight_fp8: GPUArray, # [out_features, in_features] as uint8
+ scale_inv: GPUArray, # [out_features // block_h, in_features // block_w] as bf16
+ bias: GPUArray | None = None,
+ block_size: tuple[int, int] = (128, 128),
+ ):
+ if weight_fp8.ndim != 2:
+ raise ValueError(f"weight must be 2D, got {weight_fp8.ndim}D")
+ self.weight_fp8 = weight_fp8
+ self.scale_inv = scale_inv
+ self.bias = bias
+ self.block_size = block_size
+ self.out_features = weight_fp8.shape[0]
+ self.in_features = weight_fp8.shape[1]
+
+ # Transposed weight for GEMV: [in_features, out_features]
+ # FP8 GEMV expects B[K,N] where K=in_features, N=out_features
+ self._weight_fp8_t: GPUArray | None = None
+ self._scale_inv_t: GPUArray | None = None
+
+ # Cached dequantized weight for fallback (lazy initialization)
+ self._weight_dequant: GPUArray | None = None
+ self._weight_dequant_t: GPUArray | None = None
+
+ def _ensure_transposed_fp8(self) -> None:
+ """Ensure transposed FP8 weight is available for GEMV."""
+ if self._weight_fp8_t is None:
+ # Transpose weight: [out, in] -> [in, out]
+ self._weight_fp8_t = transpose(self.weight_fp8)
+ # Transpose scale: [out/128, in/128] -> [in/128, out/128]
+ self._scale_inv_t = transpose(self.scale_inv)
+
+ def _dequantize_cpu(self) -> np.ndarray:
+ """Dequantize FP8 weight to float32 on CPU."""
+ table = self._get_fp8_table()
+
+ # Get FP8 bytes
+ fp8_np = self.weight_fp8.to_numpy()
+ if fp8_np.dtype != np.uint8:
+ fp8_np = fp8_np.view(np.uint8)
+
+ # Convert to float32
+ f32 = table[fp8_np.ravel()].reshape(fp8_np.shape)
+
+ # Get scale_inv (bf16 as uint16)
+ scale_np = self.scale_inv.to_numpy()
+ if scale_np.dtype == np.uint16:
+ scale_f32 = np.empty(scale_np.shape, dtype=np.float32)
+ scale_f32.view(np.uint32)[:] = scale_np.astype(np.uint32) << 16
+ else:
+ scale_f32 = scale_np.astype(np.float32)
+
+ # Apply block-wise scaling
+ H, W = f32.shape
+ block_h, block_w = self.block_size
+ num_blocks_h = H // block_h
+ num_blocks_w = W // block_w
+
+ f32_reshaped = f32.reshape(num_blocks_h, block_h, num_blocks_w, block_w)
+ scale_expanded = scale_f32[:, np.newaxis, :, np.newaxis]
+ f32_scaled = f32_reshaped * scale_expanded
+
+ return f32_scaled.reshape(H, W)
+
+ def _ensure_dequantized(self) -> None:
+ """Ensure dequantized weight is available (lazy init, for fallback)."""
+ if self._weight_dequant is None:
+ # Dequantize on CPU and upload to GPU
+ weight_f32 = self._dequantize_cpu()
+
+ # Convert to BF16
+ uint32_view = weight_f32.view(np.uint32)
+ weight_bf16 = ((uint32_view + 0x7FFF + ((uint32_view >> 16) & 1)) >> 16).astype(
+ np.uint16
+ )
+
+ self._weight_dequant = from_numpy(weight_bf16)
+ self._weight_dequant_t = transpose(self._weight_dequant)
+
+ def __call__(self, x: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
+ """Forward pass with online dequantization.
+
+ For M=1 (single token), uses FP8 GEMV kernel with online dequantization.
+ For M>1, uses batched FP8 GEMV kernel.
+ """
+ if x.ndim != 2:
+ raise ValueError(f"input must be 2D [batch, in_features], got {x.ndim}D")
+ if x.shape[1] != self.in_features:
+ raise ValueError(f"input features {x.shape[1]} != weight {self.in_features}")
+
+ M = x.shape[0]
+
+ if M == 1 and self._use_gemv:
+ # M=1 path: Use FP8 GEMV kernel with B[N,K] layout (no transpose needed)
+ x_1d = x.view((self.in_features,))
+
+ if out is not None:
+ out_1d = out.view((self.out_features,))
+ gemv_fp8_bf16(x_1d, self.weight_fp8, self.scale_inv, out=out_1d)
+ y = out
+ else:
+ y_1d = gemv_fp8_bf16(x_1d, self.weight_fp8, self.scale_inv)
+ y = y_1d.view((1, self.out_features))
+ else:
+ # M>1 path: Use W8A16 GEMM with FP8 TensorCore (requires transposed weights)
+ self._ensure_transposed_fp8()
+ y = w8a16_gemm_sm120(x, self._weight_fp8_t, self._scale_inv_t, out=out)
+
+ if self.bias is not None:
+ bias_add_inplace(y, self.bias)
+
+ return y
+
+
+__all__ = [
+ "LinearBF16",
+ "LinearFP8",
+ "Linear",
+]
diff --git a/src/pygpukit/llm/layers/mlp.py b/src/pygpukit/llm/layers/mlp.py
new file mode 100644
index 0000000..f423758
--- /dev/null
+++ b/src/pygpukit/llm/layers/mlp.py
@@ -0,0 +1,103 @@
+"""MLP layer implementation for PyGPUkit LLM.
+
+Provides:
+- MLP: Unified MLP supporting GELU and SwiGLU activations
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from pygpukit.core.array import GPUArray
+from pygpukit.ops.basic import (
+ concat_axis0,
+ gelu,
+ mul,
+ silu,
+)
+
+from .linear import LinearBF16, LinearFP8
+
+if TYPE_CHECKING:
+ from pygpukit.llm.config import TransformerConfig
+
+
+class MLP:
+ """Unified MLP supporting GELU and SwiGLU activations.
+
+ GELU (GPT-2 style):
+ fc1 -> GELU -> fc2
+
+ SwiGLU (LLaMA style):
+ gate_proj -> SiLU -> * up_proj -> down_proj
+
+ Supports FP8 quantized weights via LinearFP8.
+ """
+
+ def __init__(
+ self,
+ config: TransformerConfig,
+ # GELU path weights (GPUArray or LinearBF16/LinearFP8)
+ fc1_weight: GPUArray | LinearBF16 | LinearFP8 | None = None,
+ fc1_bias: GPUArray | None = None,
+ fc2_weight: GPUArray | LinearBF16 | LinearFP8 | None = None,
+ fc2_bias: GPUArray | None = None,
+ # SwiGLU path weights (GPUArray or LinearBF16/LinearFP8)
+ gate_proj: GPUArray | LinearBF16 | LinearFP8 | None = None,
+ up_proj: GPUArray | LinearBF16 | LinearFP8 | None = None,
+ down_proj: GPUArray | LinearBF16 | LinearFP8 | None = None,
+ ):
+ self.config = config
+ self.activation = config.activation
+
+ # Helper to wrap GPUArray in LinearBF16, or use pre-built LinearBF16/LinearFP8
+ def wrap_linear(
+ proj: GPUArray | LinearBF16 | LinearFP8 | None, bias: GPUArray | None = None
+ ) -> LinearBF16 | LinearFP8 | None:
+ if proj is None:
+ return None
+ if isinstance(proj, (LinearBF16, LinearFP8)):
+ return proj
+ return LinearBF16(proj, bias)
+
+ if config.activation == "gelu":
+ if fc1_weight is None or fc2_weight is None:
+ raise ValueError("GELU MLP requires fc1_weight and fc2_weight")
+ self.fc1 = wrap_linear(fc1_weight, fc1_bias)
+ self.fc2 = wrap_linear(fc2_weight, fc2_bias)
+ else: # silu (SwiGLU)
+ if gate_proj is None or up_proj is None or down_proj is None:
+ raise ValueError("SwiGLU MLP requires gate_proj, up_proj, down_proj")
+
+ self.gate_proj = wrap_linear(gate_proj)
+ self.up_proj = wrap_linear(up_proj)
+ self.down_proj = wrap_linear(down_proj)
+
+ # Get intermediate size from the projection
+ if isinstance(gate_proj, (LinearBF16, LinearFP8)):
+ self.intermediate_size = gate_proj.out_features
+ else:
+ self.intermediate_size = gate_proj.shape[0]
+
+ # Fused gate_up projection only for non-FP8 (GPUArray) weights
+ # FP8 weights can't be concatenated trivially
+ if isinstance(gate_proj, GPUArray) and isinstance(up_proj, GPUArray):
+ gate_up_weight = concat_axis0(gate_proj, up_proj)
+ self.gate_up_proj: LinearBF16 | None = LinearBF16(gate_up_weight, None)
+ else:
+ self.gate_up_proj = None
+
+ def __call__(self, x: GPUArray) -> GPUArray:
+ if self.activation == "gelu":
+ h = self.fc1(x)
+ h = gelu(h)
+ return self.fc2(h)
+ else:
+ gate = silu(self.gate_proj(x))
+ up = self.up_proj(x)
+ return self.down_proj(mul(gate, up))
+
+
+__all__ = [
+ "MLP",
+]
diff --git a/src/pygpukit/llm/layers/moe.py b/src/pygpukit/llm/layers/moe.py
new file mode 100644
index 0000000..d6a2695
--- /dev/null
+++ b/src/pygpukit/llm/layers/moe.py
@@ -0,0 +1,458 @@
+"""Mixture of Experts layer implementation for PyGPUkit LLM.
+
+Provides:
+- MoELayer: Mixture of Experts for Mixtral-style models
+"""
+
+from __future__ import annotations
+
+import os
+import time
+from functools import reduce
+from typing import TYPE_CHECKING
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.factory import zeros
+from pygpukit.ops.basic import (
+ concat_axis0,
+ mul,
+ silu,
+)
+
+from .linear import LinearBF16, LinearFP8
+from .mlp import MLP
+
+if TYPE_CHECKING:
+ from pygpukit.llm.config import TransformerConfig
+
+
+class MoELayer:
+ """Mixture of Experts layer for Mixtral-style models.
+
+ Architecture:
+ 1. Router: hidden -> [num_experts] logits
+ 2. Top-K selection with softmax
+ 3. Expert FFN (SwiGLU) for each selected expert
+ 4. Weighted combination of expert outputs
+
+ Supports FP8 quantized expert weights via LinearFP8.
+ """
+
+ def __init__(
+ self,
+ config: TransformerConfig,
+ gate_weight: GPUArray, # [num_experts, hidden_size] - router
+ expert_weights: list, # [(gate, up, down), ...] - GPUArray or LinearBF16/LinearFP8
+ ):
+ self.config = config
+ self.num_experts = config.num_experts or len(expert_weights)
+ self.num_experts_per_tok = config.num_experts_per_tok
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.moe_intermediate_size or config.intermediate_size
+
+ # Router (gate) projection
+ self.gate = LinearBF16(gate_weight)
+
+ # Expert FFNs
+ self.experts: list[MLP] = []
+ for gate_proj, up_proj, down_proj in expert_weights:
+ expert = MLP(
+ config,
+ gate_proj=gate_proj,
+ up_proj=up_proj,
+ down_proj=down_proj,
+ )
+ self.experts.append(expert)
+
+ # Check if all experts use FP8 weights for grouped GEMM optimization
+ self._use_grouped_gemm = False
+ self._stacked_gate_weight: GPUArray | None = None
+ self._stacked_gate_scale: GPUArray | None = None
+ self._stacked_up_weight: GPUArray | None = None
+ self._stacked_up_scale: GPUArray | None = None
+ self._stacked_down_weight: GPUArray | None = None
+ self._stacked_down_scale: GPUArray | None = None
+
+ # Check if first expert uses FP8 - use grouped GEMM v2 for optimization
+ # TEMP: Disabled for debugging
+ if os.environ.get("PYGPUKIT_DISABLE_GROUPED_GEMM") != "1":
+ if len(self.experts) > 0 and isinstance(self.experts[0].gate_proj, LinearFP8):
+ self._stack_fp8_weights()
+
+ # Profiling flag (set to True to enable timing)
+ _profile: bool = True
+ _profile_count: int = 0
+
+ def _stack_fp8_weights(self) -> None:
+ """Stack FP8 expert weights for grouped GEMM optimization."""
+ # Collect weights from all experts
+ gate_weights = []
+ gate_scales = []
+ up_weights = []
+ up_scales = []
+ down_weights = []
+ down_scales = []
+
+ for expert in self.experts:
+ if not isinstance(expert.gate_proj, LinearFP8):
+ return # Not all experts are FP8, abort
+
+ gate_weights.append(expert.gate_proj.weight_fp8)
+ gate_scales.append(expert.gate_proj.scale_inv)
+ up_weights.append(expert.up_proj.weight_fp8)
+ up_scales.append(expert.up_proj.scale_inv)
+ down_weights.append(expert.down_proj.weight_fp8)
+ down_scales.append(expert.down_proj.scale_inv)
+
+ # Stack weights: [num_experts, N, K]
+ # gate_proj: [intermediate_size, hidden_size] -> stacked [num_experts, intermediate_size, hidden_size]
+ # Each weight is [N, K], stack along new axis 0
+
+ def stack_arrays_fast(arrays: list[GPUArray]) -> GPUArray:
+ """Stack arrays along new axis 0 using single allocation + cudaMemcpy."""
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+
+ # Get shape info from first array
+ first = arrays[0]
+ num_arrays = len(arrays)
+ inner_shape = first.shape # [N, K] or [N/128, K/128]
+
+ # Calculate strides (nbytes is property, not method)
+ bytes_per_array = first._get_native().nbytes
+
+ # Allocate output: [num_arrays, *inner_shape]
+ out_shape = [num_arrays] + list(inner_shape)
+ out_native = native.empty(out_shape, first._get_native().dtype)
+ out = GPUArray._wrap_native(out_native)
+
+ # Copy each array to its slice using cuMemcpy
+ for i, arr in enumerate(arrays):
+ offset_bytes = i * bytes_per_array
+ native.memcpy_device_to_device_offset(
+ arr._get_native(),
+ out._get_native(),
+ 0, # src offset
+ offset_bytes, # dst offset
+ bytes_per_array,
+ )
+
+ return out
+
+ self._stacked_gate_weight = stack_arrays_fast(gate_weights)
+ self._stacked_gate_scale = stack_arrays_fast(gate_scales)
+ self._stacked_up_weight = stack_arrays_fast(up_weights)
+ self._stacked_up_scale = stack_arrays_fast(up_scales)
+ self._stacked_down_weight = stack_arrays_fast(down_weights)
+ self._stacked_down_scale = stack_arrays_fast(down_scales)
+
+ self._use_grouped_gemm = True
+ print(f"[MoE] Stacked {self.num_experts} expert weights for grouped GEMM")
+
+ def __call__(self, x: GPUArray) -> GPUArray:
+ """Forward pass through MoE layer.
+
+ Args:
+ x: Input tensor [batch, seq, hidden_size] or [seq, hidden_size]
+
+ Returns:
+ Output tensor with same shape as input
+ """
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+
+ profile = self._profile and MoELayer._profile_count < 3
+ if profile:
+ native.device_synchronize()
+ t0 = time.perf_counter()
+
+ original_shape = x.shape
+ # Flatten to [num_tokens, hidden_size]
+ if len(original_shape) == 3:
+ batch, seq, hidden = original_shape
+ num_tokens = batch * seq
+ x = x.reshape(num_tokens, hidden)
+ else:
+ num_tokens, hidden = original_shape
+
+ k = self.num_experts_per_tok
+
+ # Step 1: Compute router logits
+ router_logits = self.gate(x) # [num_tokens, num_experts]
+ if profile:
+ native.device_synchronize()
+ t1 = time.perf_counter()
+
+ # Step 2: Top-K selection
+ router_weights = zeros((num_tokens, k), dtype=x.dtype)
+ expert_indices = zeros((num_tokens, k), dtype="int32")
+ native.moe_topk_with_indices(
+ router_logits._get_native(),
+ router_weights._get_native(),
+ expert_indices._get_native(),
+ k,
+ )
+
+ # Step 3: Softmax over selected experts
+ native.moe_softmax_topk(router_weights._get_native(), k)
+
+ # Step 4: Compute permutation for efficient expert dispatch
+ expert_counts = zeros((self.num_experts,), dtype="int32")
+ expert_offsets = zeros((self.num_experts + 1,), dtype="int32")
+ permute_indices = zeros((num_tokens * k,), dtype="int32")
+ reverse_perm = zeros((num_tokens * k,), dtype="int32")
+ native.moe_compute_permutation(
+ expert_indices._get_native(),
+ expert_counts._get_native(),
+ expert_offsets._get_native(),
+ permute_indices._get_native(),
+ reverse_perm._get_native(),
+ self.num_experts,
+ k,
+ )
+
+ # Step 5: Gather hidden states for experts
+ gathered = zeros((num_tokens * k, hidden), dtype=x.dtype)
+ native.moe_gather(
+ x._get_native(),
+ permute_indices._get_native(),
+ gathered._get_native(),
+ k,
+ )
+ if profile:
+ native.device_synchronize()
+ t2 = time.perf_counter()
+
+ # Step 6: Run experts
+ if self._use_grouped_gemm:
+ # Use grouped GEMM for all experts in single kernel launches
+ from pygpukit.ops.matmul import grouped_gemm_fp8_bf16
+
+ # Create row_expert_ids from expert_offsets
+ M_total = num_tokens * k
+ row_expert_ids = zeros((M_total,), dtype="int32")
+ native.moe_expand_expert_offsets(
+ expert_offsets._get_native(),
+ row_expert_ids._get_native(),
+ self.num_experts,
+ )
+
+ # gate_proj: gathered[M_total, hidden] @ gate_weight[experts, inter, hidden]^T
+ gate_out = grouped_gemm_fp8_bf16(
+ gathered,
+ self._stacked_gate_weight,
+ self._stacked_gate_scale,
+ row_expert_ids,
+ )
+
+ # up_proj: gathered[M_total, hidden] @ up_weight[experts, inter, hidden]^T
+ up_out = grouped_gemm_fp8_bf16(
+ gathered,
+ self._stacked_up_weight,
+ self._stacked_up_scale,
+ row_expert_ids,
+ )
+
+ # SiLU(gate) * up
+ intermediate = mul(silu(gate_out), up_out)
+
+ # down_proj: intermediate[M_total, inter] @ down_weight[experts, hidden, inter]^T
+ expert_outputs = grouped_gemm_fp8_bf16(
+ intermediate,
+ self._stacked_down_weight,
+ self._stacked_down_scale,
+ row_expert_ids,
+ )
+ else:
+ # Fallback: Run experts sequentially
+ # Get expert counts on CPU for loop
+ expert_counts_cpu = expert_counts.to_numpy()
+ expert_offsets_cpu = expert_offsets.to_numpy()
+
+ # Build list of (expert_id, start, count) for non-empty experts
+ expert_tasks = []
+ for e in range(self.num_experts):
+ start = int(expert_offsets_cpu[e])
+ count = int(expert_counts_cpu[e])
+ if count > 0:
+ expert_tasks.append((e, start, count))
+
+ def run_expert(task: tuple) -> GPUArray:
+ e, start, count = task
+ expert_input = gathered[start : start + count]
+ return self.experts[e](expert_input)
+
+ # Run experts sequentially
+ expert_output_list = [run_expert(task) for task in expert_tasks]
+
+ # Concatenate all expert outputs on GPU
+ expert_outputs = reduce(concat_axis0, expert_output_list)
+
+ if profile:
+ native.device_synchronize()
+ t3 = time.perf_counter()
+
+ # Step 7: Scatter and combine outputs
+ output = zeros((num_tokens, hidden), dtype=x.dtype)
+ native.moe_scatter(
+ expert_outputs._get_native(),
+ router_weights._get_native(),
+ reverse_perm._get_native(),
+ output._get_native(),
+ k,
+ )
+ if profile:
+ native.device_synchronize()
+ t4 = time.perf_counter()
+ MoELayer._profile_count += 1
+ print(
+ f"[MoE Profile] router={t1 - t0:.3f}s, routing={t2 - t1:.3f}s, experts={t3 - t2:.3f}s, scatter={t4 - t3:.3f}s"
+ )
+
+ # Reshape back
+ if len(original_shape) == 3:
+ output = output.reshape(*original_shape)
+
+ return output
+
+ def forward_zero_alloc(
+ self,
+ x: GPUArray,
+ router_logits: GPUArray,
+ router_weights: GPUArray,
+ expert_indices: GPUArray,
+ expert_counts: GPUArray,
+ expert_offsets: GPUArray,
+ permute_indices: GPUArray,
+ reverse_perm: GPUArray,
+ row_expert_ids: GPUArray,
+ gathered: GPUArray,
+ gate_out: GPUArray,
+ up_out: GPUArray,
+ intermediate: GPUArray,
+ expert_outputs: GPUArray,
+ output: GPUArray,
+ ) -> GPUArray:
+ """Zero-allocation forward pass for CUDA Graph support.
+
+ This method uses pre-allocated buffers from DecodeBuffers to avoid
+ any memory allocations during forward pass, enabling CUDA Graph capture.
+
+ Args:
+ x: Input tensor [1, hidden_size]
+ router_logits: Pre-allocated [1, num_experts]
+ router_weights: Pre-allocated [1, k]
+ expert_indices: Pre-allocated [1, k] int32
+ expert_counts: Pre-allocated [num_experts] int32
+ expert_offsets: Pre-allocated [num_experts + 1] int32
+ permute_indices: Pre-allocated [k] int32
+ reverse_perm: Pre-allocated [k] int32
+ row_expert_ids: Pre-allocated [k] int32
+ gathered: Pre-allocated [k, hidden_size]
+ gate_out: Pre-allocated [k, moe_intermediate_size]
+ up_out: Pre-allocated [k, moe_intermediate_size]
+ intermediate: Pre-allocated [k, moe_intermediate_size]
+ expert_outputs: Pre-allocated [k, hidden_size]
+ output: Pre-allocated [1, hidden_size]
+
+ Returns:
+ The output tensor (same as output parameter)
+ """
+ from pygpukit.core.backend import get_native_module
+ from pygpukit.ops.elementwise import mul
+ from pygpukit.ops.matmul import grouped_gemm_fp8_bf16
+ from pygpukit.ops.nn import silu
+
+ native = get_native_module()
+
+ k = self.num_experts_per_tok
+
+ # Step 1: Router forward (gate projection)
+ self.gate(x, out=router_logits)
+
+ # Step 2: Top-K selection (writes to router_weights and expert_indices)
+ native.moe_topk_with_indices(
+ router_logits._get_native(),
+ router_weights._get_native(),
+ expert_indices._get_native(),
+ k,
+ )
+
+ # Step 3: Softmax over selected experts (in-place)
+ native.moe_softmax_topk(router_weights._get_native(), k)
+
+ # Step 4: Compute permutation
+ native.moe_compute_permutation(
+ expert_indices._get_native(),
+ expert_counts._get_native(),
+ expert_offsets._get_native(),
+ permute_indices._get_native(),
+ reverse_perm._get_native(),
+ self.num_experts,
+ k,
+ )
+
+ # Step 5: Gather hidden states
+ native.moe_gather(
+ x._get_native(),
+ permute_indices._get_native(),
+ gathered._get_native(),
+ k,
+ )
+
+ # Step 6: Create row_expert_ids for grouped GEMM
+ native.moe_expand_expert_offsets(
+ expert_offsets._get_native(),
+ row_expert_ids._get_native(),
+ self.num_experts,
+ )
+
+ # Step 7: Expert computation with grouped GEMM
+ # gate_proj: gathered[k, hidden] @ gate_weight[experts, inter, hidden]^T
+ grouped_gemm_fp8_bf16(
+ gathered,
+ self._stacked_gate_weight,
+ self._stacked_gate_scale,
+ row_expert_ids,
+ out=gate_out,
+ )
+
+ # up_proj: gathered[k, hidden] @ up_weight[experts, inter, hidden]^T
+ grouped_gemm_fp8_bf16(
+ gathered,
+ self._stacked_up_weight,
+ self._stacked_up_scale,
+ row_expert_ids,
+ out=up_out,
+ )
+
+ # SiLU(gate) * up -> intermediate
+ silu(gate_out, out=intermediate)
+ mul(intermediate, up_out, out=intermediate)
+
+ # down_proj: intermediate[k, inter] @ down_weight[experts, hidden, inter]^T
+ grouped_gemm_fp8_bf16(
+ intermediate,
+ self._stacked_down_weight,
+ self._stacked_down_scale,
+ row_expert_ids,
+ out=expert_outputs,
+ )
+
+ # Step 8: Scatter and combine outputs
+ native.moe_scatter(
+ expert_outputs._get_native(),
+ router_weights._get_native(),
+ reverse_perm._get_native(),
+ output._get_native(),
+ k,
+ )
+
+ return output
+
+
+__all__ = [
+ "MoELayer",
+]
diff --git a/src/pygpukit/llm/layers/norm.py b/src/pygpukit/llm/layers/norm.py
new file mode 100644
index 0000000..90e1dbc
--- /dev/null
+++ b/src/pygpukit/llm/layers/norm.py
@@ -0,0 +1,44 @@
+"""Normalization layer implementations for PyGPUkit LLM.
+
+Provides:
+- Norm: Unified RMSNorm and LayerNorm
+"""
+
+from __future__ import annotations
+
+from typing import Literal
+
+from pygpukit.core.array import GPUArray
+from pygpukit.ops.basic import (
+ layernorm,
+ rmsnorm,
+)
+
+
+class Norm:
+ """Unified normalization layer supporting RMSNorm and LayerNorm."""
+
+ def __init__(
+ self,
+ weight: GPUArray,
+ bias: GPUArray | None = None,
+ norm_type: Literal["rmsnorm", "layernorm"] = "rmsnorm",
+ eps: float = 1e-5,
+ ):
+ self.weight = weight
+ self.bias = bias
+ self.norm_type = norm_type
+ self.eps = eps
+
+ def __call__(self, x: GPUArray) -> GPUArray:
+ if self.norm_type == "rmsnorm":
+ return rmsnorm(x, self.weight, self.eps)
+ else:
+ if self.bias is None:
+ raise ValueError("LayerNorm requires bias")
+ return layernorm(x, self.weight, self.bias, self.eps)
+
+
+__all__ = [
+ "Norm",
+]
diff --git a/src/pygpukit/llm/layers/rope.py b/src/pygpukit/llm/layers/rope.py
new file mode 100644
index 0000000..1e58779
--- /dev/null
+++ b/src/pygpukit/llm/layers/rope.py
@@ -0,0 +1,48 @@
+"""Rotary Position Embedding (RoPE) utilities for PyGPUkit LLM.
+
+Provides:
+- precompute_freqs_cis: Precompute RoPE cos/sin tables
+- apply_rotary_pos_emb_numpy: Apply RoPE on CPU (numpy)
+"""
+
+from __future__ import annotations
+
+import numpy as np
+
+
+def precompute_freqs_cis(
+ head_dim: int, max_seq_len: int, theta: float = 10000.0
+) -> tuple[np.ndarray, np.ndarray]:
+ """Precompute rotary embedding cos/sin tables."""
+ freqs = 1.0 / (theta ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim))
+ t = np.arange(max_seq_len, dtype=np.float32)
+ freqs = np.outer(t, freqs)
+ cos = np.cos(freqs)
+ sin = np.sin(freqs)
+ cos = np.concatenate([cos, cos], axis=-1)
+ sin = np.concatenate([sin, sin], axis=-1)
+ return cos, sin
+
+
+def apply_rotary_pos_emb_numpy(
+ q: np.ndarray, k: np.ndarray, cos: np.ndarray, sin: np.ndarray
+) -> tuple[np.ndarray, np.ndarray]:
+ """Apply rotary position embeddings to Q and K (numpy version)."""
+
+ def rotate_half(x: np.ndarray) -> np.ndarray:
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return np.concatenate([-x2, x1], axis=-1)
+
+ cos = cos[:, np.newaxis, :]
+ sin = sin[:, np.newaxis, :]
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+__all__ = [
+ "precompute_freqs_cis",
+ "apply_rotary_pos_emb_numpy",
+]
diff --git a/src/pygpukit/llm/layers/utils.py b/src/pygpukit/llm/layers/utils.py
new file mode 100644
index 0000000..cf411c4
--- /dev/null
+++ b/src/pygpukit/llm/layers/utils.py
@@ -0,0 +1,65 @@
+"""Weight repacking utilities for PyGPUkit LLM.
+
+Provides:
+- repack_weight: Repack weight tensor into contiguous GPU buffer
+- repack_linear: Repack LinearBF16 layer weights in-place
+- repack_norm: Repack Norm layer weights in-place
+"""
+
+from __future__ import annotations
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.factory import from_numpy
+
+from .linear import LinearBF16
+from .norm import Norm
+
+
+def repack_weight(weight: GPUArray) -> GPUArray:
+ """Repack a weight tensor into a new contiguous GPU buffer.
+
+ This fixes performance issues caused by fragmented GPU memory allocation.
+ Weights allocated later during model loading may end up in suboptimal
+ memory regions, causing 7x slower matmul performance.
+
+ Args:
+ weight: Original weight tensor on GPU
+
+ Returns:
+ New GPUArray with same data in freshly allocated contiguous memory
+ """
+ # Copy to CPU, then back to GPU to get fresh allocation
+ # This ensures the new buffer is allocated contiguously
+ weight_np = weight.to_numpy()
+ return from_numpy(weight_np)
+
+
+def repack_linear(linear: LinearBF16) -> None:
+ """Repack a LinearBF16 layer's weight in-place.
+
+ Args:
+ linear: LinearBF16 layer to repack
+ """
+ linear.weight = repack_weight(linear.weight)
+ # Clear transpose cache - will be regenerated on first use
+ linear._weight_t = None
+ if linear.bias is not None:
+ linear.bias = repack_weight(linear.bias)
+
+
+def repack_norm(norm: Norm) -> None:
+ """Repack a Norm layer's weight in-place.
+
+ Args:
+ norm: Norm layer to repack
+ """
+ norm.weight = repack_weight(norm.weight)
+ if norm.bias is not None:
+ norm.bias = repack_weight(norm.bias)
+
+
+__all__ = [
+ "repack_weight",
+ "repack_linear",
+ "repack_norm",
+]
diff --git a/src/pygpukit/llm/loader.py b/src/pygpukit/llm/loader.py
index eb47246..ec30863 100644
--- a/src/pygpukit/llm/loader.py
+++ b/src/pygpukit/llm/loader.py
@@ -5,13 +5,11 @@
- load_gpt2_from_safetensors: GPT-2 specific loader
- load_llama_from_safetensors: LLaMA specific loader
- load_qwen3_from_safetensors: Qwen3 specific loader
-- repack_model_weights: Optimize GPU memory placement
-- FP8 dequantization: Block-wise FP8 E4M3 to BF16/FP16 conversion
+- load_mixtral_from_safetensors: Mixtral MoE specific loader
"""
from __future__ import annotations
-from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
@@ -40,397 +38,21 @@
TransformerBlock,
)
-if TYPE_CHECKING:
- from pygpukit.llm import SafeTensorsFile, ShardedSafeTensorsFile
- from pygpukit.llm.model import CausalTransformerModel
-
-
-# =============================================================================
-# FP8 Quantization Support
-# =============================================================================
-
-
-@dataclass
-class FP8QuantConfig:
- """FP8 quantization configuration from HuggingFace config.json."""
-
- quant_method: str # "fp8"
- fmt: str # "e4m3" or "e5m2"
- weight_block_size: tuple[int, int] # e.g., (128, 128)
- modules_to_not_convert: list[str] # List of module name patterns to skip
-
- @classmethod
- def from_config(cls, config: dict) -> FP8QuantConfig | None:
- """Parse quantization config from HF config.json."""
- qc = config.get("quantization_config")
- if qc is None or qc.get("quant_method") != "fp8":
- return None
-
- block_size = qc.get("weight_block_size", [128, 128])
- return cls(
- quant_method="fp8",
- fmt=qc.get("fmt", "e4m3"),
- weight_block_size=(block_size[0], block_size[1]),
- modules_to_not_convert=qc.get("modules_to_not_convert", []),
- )
-
-
-# =============================================================================
-# QAT/QAD Quantization Support (Issue #115)
-# =============================================================================
-
-
-@dataclass
-class QATQuantConfig:
- """QAT (Quantization-Aware Training) configuration.
-
- Supports models trained with:
- - NVIDIA TensorRT Model Optimizer
- - HuggingFace Optimum
- - PyTorch Quantization
-
- Reference:
- - https://nvidia.github.io/TensorRT-Model-Optimizer/
- - https://developer.nvidia.com/blog/top-5-ai-model-optimization-techniques-for-faster-smarter-inference/
- """
-
- quant_method: str # "qat", "modelopt", "nvfp4", etc.
- quant_algo: str # "FP8", "INT8", "NVFP4", "W8A8", etc.
- group_size: int # Block/group size for quantization
- kv_cache_quant_algo: str | None # KV cache quantization (optional)
- exclude_modules: list[str] # Modules to skip quantization
- producer: str | None # Tool that produced the checkpoint (e.g., "modelopt")
- producer_version: str | None # Version of the producer tool
-
- @classmethod
- def from_config(cls, config: dict) -> QATQuantConfig | None:
- """Parse QAT config from HF config.json or hf_quant_config.json."""
- # Check for TensorRT Model Optimizer format (hf_quant_config.json style)
- if "producer" in config and "quantization" in config:
- producer_info = config.get("producer", {})
- quant_info = config.get("quantization", {})
- return cls(
- quant_method="modelopt",
- quant_algo=quant_info.get("quant_algo", "unknown"),
- group_size=quant_info.get("group_size", 128),
- kv_cache_quant_algo=quant_info.get("kv_cache_quant_algo"),
- exclude_modules=quant_info.get("exclude_modules", []),
- producer=producer_info.get("name"),
- producer_version=producer_info.get("version"),
- )
-
- # Check for HF quantization_config with QAT method
- qc = config.get("quantization_config")
- if qc is None:
- return None
-
- quant_method = qc.get("quant_method", "")
- # QAT methods: "qat", "awq", "gptq", etc. (exclude "fp8" which is handled separately)
- qat_methods = {"qat", "awq", "gptq", "bnb", "modelopt"}
- if quant_method not in qat_methods:
- return None
-
- return cls(
- quant_method=quant_method,
- quant_algo=qc.get("quant_algo", qc.get("bits", "unknown")),
- group_size=qc.get("group_size", qc.get("block_size", 128)),
- kv_cache_quant_algo=qc.get("kv_cache_quant_algo"),
- exclude_modules=qc.get("modules_to_not_convert", []),
- producer=None,
- producer_version=None,
- )
-
-
-# =============================================================================
-# Pruning Support (Issue #115)
-# =============================================================================
-
-
-@dataclass
-class PruningConfig:
- """Pruning configuration for structurally smaller models.
-
- Supports models pruned with:
- - NVIDIA TensorRT Model Optimizer
- - HuggingFace nn_pruning
- - Neural Compressor
-
- Reference:
- - https://github.com/huggingface/nn_pruning
- - https://github.com/NVIDIA/TensorRT-Model-Optimizer
- """
-
- pruning_method: str # "magnitude", "movement", "structured", "unstructured"
- sparsity: float # Target sparsity (0.0 to 1.0)
- pruned_heads: dict[int, list[int]] | None # Layer -> pruned head indices
- is_structured: bool # True if structured pruning (removes entire heads/neurons)
-
- @classmethod
- def from_config(cls, config: dict) -> PruningConfig | None:
- """Parse pruning config from HF config.json."""
- # Check for pruned_heads (HuggingFace standard)
- pruned_heads = config.get("pruned_heads")
- if pruned_heads:
- # Convert string keys to int if needed
- if isinstance(pruned_heads, dict):
- pruned_heads = {int(k): v for k, v in pruned_heads.items()}
- return cls(
- pruning_method="structured",
- sparsity=0.0, # Unknown from config alone
- pruned_heads=pruned_heads,
- is_structured=True,
- )
-
- # Check for pruning_config section
- pc = config.get("pruning_config")
- if pc is None:
- return None
-
- return cls(
- pruning_method=pc.get("pruning_type", pc.get("method", "unknown")),
- sparsity=pc.get("target_sparsity", pc.get("sparsity", 0.0)),
- pruned_heads=pc.get("pruned_heads"),
- is_structured=pc.get("is_structured", pc.get("structured", False)),
- )
-
-
-# =============================================================================
-# Sparsity Pattern Support (Issue #115)
-# =============================================================================
-
-
-@dataclass
-class SparsityConfig:
- """Sparsity pattern configuration for sparse tensor operations.
-
- Supports:
- - 2:4 structured sparsity (Ampere+)
- - Block sparsity patterns
- - Custom sparsity masks
-
- Reference:
- - https://developer.nvidia.com/blog/accelerating-inference-with-sparsity-using-ampere-and-tensorrt/
- """
-
- pattern: str # "2:4", "4:8", "block", "unstructured"
- block_size: tuple[int, int] | None # For block sparsity
- density: float # Non-zero ratio (1 - sparsity)
-
- @classmethod
- def from_config(cls, config: dict) -> SparsityConfig | None:
- """Parse sparsity config from HF config.json."""
- sc = config.get("sparsity_config")
- if sc is None:
- # Check for sparsity in quantization_config
- qc = config.get("quantization_config", {})
- sparsity_pattern = qc.get("sparsity_pattern")
- if sparsity_pattern:
- return cls(
- pattern=sparsity_pattern,
- block_size=None,
- density=1.0 - qc.get("sparsity", 0.5),
- )
- return None
-
- pattern = sc.get("pattern", sc.get("sparsity_pattern", "unknown"))
- block_size = sc.get("block_size")
- if block_size and isinstance(block_size, list):
- block_size = tuple(block_size)
-
- return cls(
- pattern=pattern,
- block_size=block_size,
- density=sc.get("density", 1.0 - sc.get("sparsity", 0.0)),
- )
-
- def is_2_4_sparse(self) -> bool:
- """Check if this is 2:4 structured sparsity (Ampere+ TensorCore)."""
- return self.pattern == "2:4"
-
-
-# =============================================================================
-# Model Optimization Info (Issue #115)
-# =============================================================================
-
-
-@dataclass
-class ModelOptimizationInfo:
- """Combined optimization information for a model.
-
- Aggregates all optimization techniques applied to the model:
- - Quantization (FP8, QAT, etc.)
- - Pruning (structured, unstructured)
- - Sparsity (2:4, block)
- """
-
- fp8_config: FP8QuantConfig | None
- qat_config: QATQuantConfig | None
- pruning_config: PruningConfig | None
- sparsity_config: SparsityConfig | None
-
- @classmethod
- def from_config(cls, config: dict) -> ModelOptimizationInfo:
- """Parse all optimization configs from config.json."""
- return cls(
- fp8_config=FP8QuantConfig.from_config(config),
- qat_config=QATQuantConfig.from_config(config),
- pruning_config=PruningConfig.from_config(config),
- sparsity_config=SparsityConfig.from_config(config),
- )
-
- def has_any_optimization(self) -> bool:
- """Check if any optimization is applied."""
- return any(
- [
- self.fp8_config,
- self.qat_config,
- self.pruning_config,
- self.sparsity_config,
- ]
- )
-
- def summary(self) -> str:
- """Return a summary string of optimizations."""
- parts = []
- if self.fp8_config:
- parts.append(f"FP8({self.fp8_config.fmt})")
- if self.qat_config:
- parts.append(f"QAT({self.qat_config.quant_algo})")
- if self.pruning_config:
- parts.append(f"Pruned({self.pruning_config.pruning_method})")
- if self.sparsity_config:
- parts.append(f"Sparse({self.sparsity_config.pattern})")
- return ", ".join(parts) if parts else "None"
-
-
-# FP8 E4M3 to float32 lookup table (256 entries)
-# Format: 1 sign bit, 4 exponent bits, 3 mantissa bits
-# Special values: NaN (0x7F/0xFF), no infinity
-_FP8_E4M3_TO_F32_TABLE: np.ndarray | None = None
-
-
-def _get_fp8_e4m3_table() -> np.ndarray:
- """Build FP8 E4M3 to float32 conversion lookup table."""
- global _FP8_E4M3_TO_F32_TABLE
- if _FP8_E4M3_TO_F32_TABLE is not None:
- return _FP8_E4M3_TO_F32_TABLE
-
- table = np.zeros(256, dtype=np.float32)
- for i in range(256):
- # Extract components
- sign = (i >> 7) & 1
- exp = (i >> 3) & 0xF # 4 exponent bits
- mant = i & 0x7 # 3 mantissa bits
-
- if exp == 0xF and mant == 0x7:
- # NaN (0x7F and 0xFF)
- table[i] = np.nan
- elif exp == 0:
- # Subnormal (exponent = 0)
- # Value = (-1)^sign * 2^(-6) * (0.mantissa)
- value = (mant / 8.0) * (2.0**-6)
- table[i] = -value if sign else value
- else:
- # Normal
- # Value = (-1)^sign * 2^(exp-7) * (1.mantissa)
- value = (1.0 + mant / 8.0) * (2.0 ** (exp - 7))
- table[i] = -value if sign else value
-
- _FP8_E4M3_TO_F32_TABLE = table
- return table
-
-
-def dequantize_fp8_e4m3_block(
- fp8_bytes: np.ndarray,
- scale_inv: np.ndarray,
- block_size: tuple[int, int] = (128, 128),
-) -> np.ndarray:
- """Dequantize FP8 E4M3 weight with block-wise scaling.
-
- Args:
- fp8_bytes: Raw FP8 data as uint8 array, shape [H, W]
- scale_inv: Inverse scale factors, shape [H//block_h, W//block_w]
- block_size: Block size for quantization (default 128x128)
-
- Returns:
- Dequantized float32 array, shape [H, W]
- """
- # Convert FP8 bytes to float32 using lookup table
- table = _get_fp8_e4m3_table()
- f32 = table[fp8_bytes.ravel()].reshape(fp8_bytes.shape)
-
- # Apply block-wise scaling
- H, W = f32.shape
- block_h, block_w = block_size
-
- # Ensure scale_inv is float32 for computation
- if scale_inv.dtype != np.float32:
- # BF16 stored as uint16 -> convert to float32
- if scale_inv.dtype == np.uint16:
- scale_f32 = np.empty(scale_inv.shape, dtype=np.float32)
- scale_f32.view(np.uint32)[:] = scale_inv.astype(np.uint32) << 16
- else:
- scale_f32 = scale_inv.astype(np.float32)
- else:
- scale_f32 = scale_inv
-
- # Apply scaling per block using broadcasting
- num_blocks_h = H // block_h
- num_blocks_w = W // block_w
-
- # Reshape for vectorized block scaling
- f32_reshaped = f32.reshape(num_blocks_h, block_h, num_blocks_w, block_w)
- scale_expanded = scale_f32[:, np.newaxis, :, np.newaxis]
- f32_scaled = f32_reshaped * scale_expanded
- result = f32_scaled.reshape(H, W)
-
- return result
-
-
-def is_fp8_weight(tensor_name: str, tensor_names: list[str]) -> bool:
- """Check if a weight tensor has an FP8 scale tensor."""
- scale_name = tensor_name + "_scale_inv"
- return scale_name in tensor_names
-
-
-def load_fp8_weight_direct(
- st: SafeTensorsFile | ShardedSafeTensorsFile,
- weight_name: str,
- block_size: tuple[int, int] = (128, 128),
-) -> tuple[GPUArray, GPUArray]:
- """Load FP8 weight directly without dequantization.
-
- Returns:
- (weight_fp8, scale_inv) tuple:
- - weight_fp8: [out_features, in_features] as uint8
- - scale_inv: [out/block_h, in/block_w] as bf16
- """
- from pygpukit.core.factory import from_numpy
- from pygpukit.llm import Dtype
-
- # Load FP8 weight as uint8
- info = st.tensor_info(weight_name)
- data = st.tensor_bytes(weight_name)
- fp8_bytes = np.frombuffer(data, dtype=np.uint8).reshape(info.shape).copy()
- weight_fp8 = from_numpy(fp8_bytes)
-
- # Load scale_inv tensor
- scale_name = weight_name + "_scale_inv"
- scale_info = st.tensor_info(scale_name)
- scale_data = st.tensor_bytes(scale_name)
-
- # scale_inv is typically bfloat16
- if scale_info.dtype == Dtype.BFloat16:
- scale_inv = np.frombuffer(scale_data, dtype=np.uint16).reshape(scale_info.shape).copy()
- else:
- # Convert float32 to bfloat16
- scale_f32 = np.frombuffer(scale_data, dtype=np.float32).reshape(scale_info.shape)
- uint32_view = scale_f32.view(np.uint32)
- scale_inv = ((uint32_view + 0x7FFF + ((uint32_view >> 16) & 1)) >> 16).astype(np.uint16)
+# Re-export quantization configs and utilities from quant module
+from pygpukit.llm.quant import (
+ FP8QuantConfig,
+ ModelOptimizationInfo,
+ PruningConfig,
+ QATQuantConfig,
+ SparsityConfig,
+ load_fp8_weight_direct,
+)
- scale_inv_gpu = from_numpy(scale_inv)
+# Re-export repack function
+from pygpukit.llm.repack import repack_model_weights
- return weight_fp8, scale_inv_gpu
+if TYPE_CHECKING:
+ from pygpukit.llm.model import CausalTransformerModel
# =============================================================================
@@ -502,277 +124,6 @@ def load_mixtral_from_safetensors(
return load_model_from_safetensors(model_path, dtype=dtype, spec=MIXTRAL_SPEC)
-# =============================================================================
-# Model Weight Repacking
-# =============================================================================
-
-
-def repack_model_weights(model: CausalTransformerModel) -> None:
- """Repack all model weights into contiguous GPU memory.
-
- This fixes severe performance regression (7x slowdown) caused by
- fragmented GPU memory allocation during model loading. Weights
- allocated later end up in suboptimal memory regions.
-
- The repacking is done in two phases:
- 1. Convert ALL weights to numpy (freeing GPU memory)
- 2. Reallocate ALL weights fresh in contiguous memory
-
- Args:
- model: CausalTransformerModel to repack in-place
-
- Note:
- MoE models are currently skipped (not repacked) due to different
- weight structure. This will be addressed in a future update.
- """
- import gc
-
- # Skip repacking for MoE models (different weight structure)
- if model.blocks and isinstance(model.blocks[0].mlp, MoELayer):
- return
-
- # Phase 1: Collect all weights as numpy arrays
- numpy_cache: dict[int, dict] = {}
- dummy_arrays: list[GPUArray] = []
-
- # Embedding
- embed_np = model.embed_tokens.to_numpy()
- model.embed_tokens = None # type: ignore
-
- # Position embedding
- pos_embed_np = None
- if model.position_embed is not None:
- pos_embed_np = model.position_embed.to_numpy()
- model.position_embed = None
-
- # lm_head
- lm_head_np = None
- if model._lm_head is not None:
- lm_head_np = model._lm_head.to_numpy()
- model._lm_head = None
-
- # Final norm
- final_norm_weight_np = model.final_norm.weight.to_numpy()
- final_norm_bias_np = None
- if model.final_norm.bias is not None:
- final_norm_bias_np = model.final_norm.bias.to_numpy()
- model.final_norm.weight = None # type: ignore
- model.final_norm.bias = None
-
- # All blocks
- for i, block in enumerate(model.blocks):
- numpy_cache[i] = {}
-
- # Attention norms
- numpy_cache[i]["attn_norm_w"] = block.attn_norm.weight.to_numpy()
- numpy_cache[i]["attn_norm_b"] = (
- block.attn_norm.bias.to_numpy() if block.attn_norm.bias is not None else None
- )
- block.attn_norm.weight = None # type: ignore
- block.attn_norm.bias = None
-
- numpy_cache[i]["mlp_norm_w"] = block.mlp_norm.weight.to_numpy()
- numpy_cache[i]["mlp_norm_b"] = (
- block.mlp_norm.bias.to_numpy() if block.mlp_norm.bias is not None else None
- )
- block.mlp_norm.weight = None # type: ignore
- block.mlp_norm.bias = None
-
- # Attention projections
- attn = block.attn
- numpy_cache[i]["q_w"] = attn.q_proj.weight.to_numpy()
- numpy_cache[i]["q_b"] = (
- attn.q_proj.bias.to_numpy() if attn.q_proj.bias is not None else None
- )
- attn.q_proj.weight = None # type: ignore
- attn.q_proj.bias = None
- attn.q_proj._weight_t = None
-
- numpy_cache[i]["k_w"] = attn.k_proj.weight.to_numpy()
- numpy_cache[i]["k_b"] = (
- attn.k_proj.bias.to_numpy() if attn.k_proj.bias is not None else None
- )
- attn.k_proj.weight = None # type: ignore
- attn.k_proj.bias = None
- attn.k_proj._weight_t = None
-
- numpy_cache[i]["v_w"] = attn.v_proj.weight.to_numpy()
- numpy_cache[i]["v_b"] = (
- attn.v_proj.bias.to_numpy() if attn.v_proj.bias is not None else None
- )
- attn.v_proj.weight = None # type: ignore
- attn.v_proj.bias = None
- attn.v_proj._weight_t = None
-
- numpy_cache[i]["o_w"] = attn.o_proj.weight.to_numpy()
- numpy_cache[i]["o_b"] = (
- attn.o_proj.bias.to_numpy() if attn.o_proj.bias is not None else None
- )
- attn.o_proj.weight = None # type: ignore
- attn.o_proj.bias = None
- attn.o_proj._weight_t = None
-
- # QK norms
- if attn.q_norm is not None:
- numpy_cache[i]["q_norm_w"] = attn.q_norm.weight.to_numpy()
- numpy_cache[i]["q_norm_b"] = (
- attn.q_norm.bias.to_numpy() if attn.q_norm.bias is not None else None
- )
- attn.q_norm.weight = None # type: ignore
- attn.q_norm.bias = None
- if attn.k_norm is not None:
- numpy_cache[i]["k_norm_w"] = attn.k_norm.weight.to_numpy()
- numpy_cache[i]["k_norm_b"] = (
- attn.k_norm.bias.to_numpy() if attn.k_norm.bias is not None else None
- )
- attn.k_norm.weight = None # type: ignore
- attn.k_norm.bias = None
-
- # MLP projections
- mlp = block.mlp
- if mlp.activation == "gelu":
- numpy_cache[i]["fc1_w"] = mlp.fc1.weight.to_numpy()
- numpy_cache[i]["fc1_b"] = mlp.fc1.bias.to_numpy() if mlp.fc1.bias is not None else None
- mlp.fc1.weight = None # type: ignore
- mlp.fc1.bias = None
- mlp.fc1._weight_t = None
-
- numpy_cache[i]["fc2_w"] = mlp.fc2.weight.to_numpy()
- numpy_cache[i]["fc2_b"] = mlp.fc2.bias.to_numpy() if mlp.fc2.bias is not None else None
- mlp.fc2.weight = None # type: ignore
- mlp.fc2.bias = None
- mlp.fc2._weight_t = None
- else: # SwiGLU
- numpy_cache[i]["gate_w"] = mlp.gate_proj.weight.to_numpy()
- numpy_cache[i]["gate_b"] = (
- mlp.gate_proj.bias.to_numpy() if mlp.gate_proj.bias is not None else None
- )
- mlp.gate_proj.weight = None # type: ignore
- mlp.gate_proj.bias = None
- mlp.gate_proj._weight_t = None
-
- numpy_cache[i]["up_w"] = mlp.up_proj.weight.to_numpy()
- numpy_cache[i]["up_b"] = (
- mlp.up_proj.bias.to_numpy() if mlp.up_proj.bias is not None else None
- )
- mlp.up_proj.weight = None # type: ignore
- mlp.up_proj.bias = None
- mlp.up_proj._weight_t = None
-
- numpy_cache[i]["down_w"] = mlp.down_proj.weight.to_numpy()
- numpy_cache[i]["down_b"] = (
- mlp.down_proj.bias.to_numpy() if mlp.down_proj.bias is not None else None
- )
- mlp.down_proj.weight = None # type: ignore
- mlp.down_proj.bias = None
- mlp.down_proj._weight_t = None
-
- # Force garbage collection to free GPU memory
- gc.collect()
-
- # Allocate dummy arrays to fill the freed memory space
- dummy_size = 1024 * 1024 * 512 # 512M elements = 1GB for FP16
- try:
- for _ in range(16): # Allocate ~16GB of dummy memory
- dummy = from_numpy(np.zeros(dummy_size, dtype=np.float16))
- dummy_arrays.append(dummy)
- except Exception:
- pass # Continue with whatever dummy memory we could allocate
-
- # Phase 2: Reallocate all weights fresh (REVERSE order for memory optimization)
- for i in reversed(range(len(model.blocks))):
- block = model.blocks[i]
- cache = numpy_cache[i]
-
- # Attention norms
- block.attn_norm.weight = from_numpy(cache["attn_norm_w"])
- if cache["attn_norm_b"] is not None:
- block.attn_norm.bias = from_numpy(cache["attn_norm_b"])
-
- block.mlp_norm.weight = from_numpy(cache["mlp_norm_w"])
- if cache["mlp_norm_b"] is not None:
- block.mlp_norm.bias = from_numpy(cache["mlp_norm_b"])
-
- # Attention projections
- attn = block.attn
- attn.q_proj.weight = from_numpy(cache["q_w"])
- if cache["q_b"] is not None:
- attn.q_proj.bias = from_numpy(cache["q_b"])
-
- attn.k_proj.weight = from_numpy(cache["k_w"])
- if cache["k_b"] is not None:
- attn.k_proj.bias = from_numpy(cache["k_b"])
-
- attn.v_proj.weight = from_numpy(cache["v_w"])
- if cache["v_b"] is not None:
- attn.v_proj.bias = from_numpy(cache["v_b"])
-
- attn.o_proj.weight = from_numpy(cache["o_w"])
- if cache["o_b"] is not None:
- attn.o_proj.bias = from_numpy(cache["o_b"])
-
- # QK norms
- if "q_norm_w" in cache:
- attn.q_norm.weight = from_numpy(cache["q_norm_w"])
- if cache["q_norm_b"] is not None:
- attn.q_norm.bias = from_numpy(cache["q_norm_b"])
- if "k_norm_w" in cache:
- attn.k_norm.weight = from_numpy(cache["k_norm_w"])
- if cache["k_norm_b"] is not None:
- attn.k_norm.bias = from_numpy(cache["k_norm_b"])
-
- # MLP projections
- mlp = block.mlp
- if mlp.activation == "gelu":
- mlp.fc1.weight = from_numpy(cache["fc1_w"])
- if cache["fc1_b"] is not None:
- mlp.fc1.bias = from_numpy(cache["fc1_b"])
-
- mlp.fc2.weight = from_numpy(cache["fc2_w"])
- if cache["fc2_b"] is not None:
- mlp.fc2.bias = from_numpy(cache["fc2_b"])
- else: # SwiGLU
- mlp.gate_proj.weight = from_numpy(cache["gate_w"])
- if cache["gate_b"] is not None:
- mlp.gate_proj.bias = from_numpy(cache["gate_b"])
-
- mlp.up_proj.weight = from_numpy(cache["up_w"])
- if cache["up_b"] is not None:
- mlp.up_proj.bias = from_numpy(cache["up_b"])
-
- mlp.down_proj.weight = from_numpy(cache["down_w"])
- if cache["down_b"] is not None:
- mlp.down_proj.bias = from_numpy(cache["down_b"])
-
- # Clear this block's cache immediately
- del numpy_cache[i]
-
- # Final norm
- model.final_norm.weight = from_numpy(final_norm_weight_np)
- if final_norm_bias_np is not None:
- model.final_norm.bias = from_numpy(final_norm_bias_np)
-
- # lm_head
- if lm_head_np is not None:
- model._lm_head = from_numpy(lm_head_np)
-
- # Embedding and position embedding last
- model.embed_tokens = from_numpy(embed_np)
- del embed_np
-
- if pos_embed_np is not None:
- model.position_embed = from_numpy(pos_embed_np)
- del pos_embed_np
-
- # Clear any cached transposes
- if hasattr(model, "_lm_head_t_cache"):
- delattr(model, "_lm_head_t_cache")
-
- # Free dummy arrays
- del dummy_arrays
- gc.collect()
-
-
# =============================================================================
# Generic Model Loader using ModelSpec
# =============================================================================
@@ -806,8 +157,8 @@ def load_model_from_safetensors(
model = load_model_from_safetensors("/path/to/model.safetensors", spec=LLAMA_SPEC)
"""
# Import here to avoid circular import
- from pygpukit.llm import Dtype, load_safetensors
from pygpukit.llm.model import CausalTransformerModel
+ from pygpukit.llm.safetensors import Dtype, load_safetensors
st = load_safetensors(model_path)
@@ -1241,3 +592,21 @@ def expert_name(pattern: str, layer: int, expert: int) -> str:
if repack_weights:
repack_model_weights(model)
return model
+
+
+__all__ = [
+ # Main loaders
+ "load_model_from_safetensors",
+ "load_gpt2_from_safetensors",
+ "load_llama_from_safetensors",
+ "load_qwen3_from_safetensors",
+ "load_mixtral_from_safetensors",
+ # Weight repacking
+ "repack_model_weights",
+ # Quantization configs (re-exported)
+ "FP8QuantConfig",
+ "QATQuantConfig",
+ "PruningConfig",
+ "SparsityConfig",
+ "ModelOptimizationInfo",
+]
diff --git a/src/pygpukit/llm/model.py b/src/pygpukit/llm/model.py
index 4b35245..47d3153 100644
--- a/src/pygpukit/llm/model.py
+++ b/src/pygpukit/llm/model.py
@@ -1,1501 +1,35 @@
"""CausalTransformerModel implementation for PyGPUkit.
-Provides the unified Transformer runtime for GPT-2, LLaMA, and Qwen3 architectures.
-Model-specific behavior is controlled by the ModelSpec configuration.
-
-Key features:
-- Hybrid Attention: CPU for seq_len=1 (decode), GPU for prefill
-- GPU-native operations: RMSNorm, LayerNorm, SDPA, SiLU, GELU, RoPE
-- CUDA Graph support for zero-allocation decode
-- Speculative and Jacobi decoding modes
+This module re-exports from llm/models/ for backwards compatibility.
+See llm/models/causal.py for the actual implementation.
"""
from __future__ import annotations
-from collections.abc import Generator
-from typing import TYPE_CHECKING, Literal
-
-import numpy as np
-
-from pygpukit.core.array import GPUArray
-from pygpukit.core.factory import from_numpy
-
-# Import from refactored modules
-from pygpukit.llm.buffers import DecodeBuffers, PrefillBuffers
-from pygpukit.llm.config import ModelSpec, TransformerConfig
-from pygpukit.llm.layers import (
- MLP,
- Attention,
- Norm,
- TransformerBlock,
+# Re-export everything from models/
+from pygpukit.llm.models.causal import (
+ CausalSelfAttention,
+ CausalTransformerModel,
+ GPT2Model,
+ LayerNorm,
+ LlamaAttention,
+ LlamaBlock,
+ LlamaMLP,
+ LlamaModel,
+ RMSNorm,
)
-from pygpukit.llm.sampling import sample_token
-from pygpukit.ops.basic import (
- add,
- add_inplace,
- bias_add_inplace,
- copy_to,
- embedding_lookup,
- embedding_lookup_ptr,
- gelu,
- kv_cache_update_gqa,
- kv_cache_update_gqa_ptr,
- matmul,
- mul_inplace,
- repeat_interleave_axis1,
- reshape_copy,
- rmsnorm,
- rope_inplace,
- sample_token_gpu,
- sdpa_causal,
- sdpa_causal_fixed_cache,
- sdpa_causal_fixed_cache_ptr,
- silu,
- transpose,
- transpose_3d_021,
-)
-
-if TYPE_CHECKING:
- pass
-
-
-def _to_float32_logits(logits_np: np.ndarray) -> np.ndarray:
- """Convert logits to float32 for sampling.
-
- If logits are stored as uint16 (bfloat16 representation), convert them
- to float32. Otherwise return as-is.
- """
- if logits_np.dtype == np.uint16:
- # bfloat16 stored as uint16: convert to float32
- return (logits_np.astype(np.uint32) << 16).view(np.float32)
- return logits_np.astype(np.float32)
-
-
-# =============================================================================
-# Unified CausalTransformerModel
-# =============================================================================
-
-
-class CausalTransformerModel:
- """Unified causal transformer model.
-
- The single runtime model for all architectures (GPT-2, LLaMA, Qwen3).
- Model-specific behavior is controlled by the spec attribute.
- """
-
- # Type hints for dynamically added attributes
- _batch_decode_buffers: DecodeBuffers | None
- _batch_token_ids_np: np.ndarray
-
- def __init__(
- self,
- config: TransformerConfig,
- embed_tokens: GPUArray,
- blocks: list[TransformerBlock],
- final_norm: Norm,
- lm_head: GPUArray | None = None,
- position_embed: GPUArray | None = None, # For GPT-2 style
- spec: ModelSpec | None = None,
- ):
- self.config = config
- self.embed_tokens = embed_tokens
- self.blocks = blocks
- self.final_norm = final_norm
- self._lm_head = lm_head
- self.position_embed = position_embed
- self.spec = spec
-
- def __call__(
- self,
- input_ids: list[int],
- position_ids: list[int] | None = None,
- past_key_values: list[tuple | None] | None = None,
- use_cache: bool = False,
- ) -> tuple[GPUArray, list[tuple | None] | None]:
- """Forward pass.
-
- Args:
- input_ids: Token IDs [seq_len]
- position_ids: Position IDs (auto-generated if None)
- past_key_values: List of (k, v) tuples per layer
- use_cache: Whether to return KV cache
-
- Returns:
- Tuple of (hidden_states, present_key_values)
- """
- seq_len = len(input_ids)
-
- if position_ids is None:
- if past_key_values is not None and past_key_values[0] is not None:
- past_len = past_key_values[0][0].shape[0]
- position_ids = list(range(past_len, past_len + seq_len))
- else:
- position_ids = list(range(seq_len))
-
- # Token embeddings (cache numpy array to avoid repeated GPU->CPU transfer)
- if not hasattr(self, "_embed_np_cache"):
- self._embed_np_cache = self.embed_tokens.to_numpy()
- hidden_np = self._embed_np_cache[input_ids]
-
- # Add position embeddings (GPT-2 style)
- if self.position_embed is not None:
- if not hasattr(self, "_pos_embed_np_cache"):
- self._pos_embed_np_cache = self.position_embed.to_numpy()
- hidden_np = hidden_np + self._pos_embed_np_cache[position_ids]
-
- hidden: GPUArray = from_numpy(hidden_np.astype(self._embed_np_cache.dtype))
-
- # Transformer blocks
- present_key_values = []
- for i, block in enumerate(self.blocks):
- past_kv = past_key_values[i] if past_key_values else None
- hidden, present_kv = block(hidden, position_ids, past_kv, use_cache)
- present_key_values.append(present_kv)
-
- # Final norm
- hidden = self.final_norm(hidden)
-
- if use_cache:
- return hidden, present_key_values
- return hidden, None
-
- @property
- def lm_head(self) -> GPUArray | None:
- """LM head weights (for backward compatibility)."""
- return self._lm_head
-
- def get_logits(self, hidden: GPUArray) -> GPUArray:
- """Compute logits from hidden states on GPU."""
- # Cache transposed lm_head to avoid repeated transpose
- if not hasattr(self, "_lm_head_t_cache"):
- lm_head = self._lm_head if self._lm_head is not None else self.embed_tokens
- self._lm_head_t_cache = transpose(lm_head)
-
- # GPU matmul: hidden @ lm_head.T
- # hidden: [seq_len, hidden_size], lm_head: [vocab_size, hidden_size]
- # Result: [seq_len, vocab_size]
- return matmul(hidden, self._lm_head_t_cache)
-
- def generate(
- self,
- input_ids: list[int],
- max_new_tokens: int = 20,
- temperature: float = 1.0,
- top_k: int = 50,
- top_p: float = 0.9,
- eos_token_id: int | None = None,
- use_cache: bool = True,
- gpu_sampling: bool = False,
- ) -> list[int]:
- """Generate tokens autoregressively.
-
- Args:
- input_ids: Initial token IDs
- max_new_tokens: Maximum new tokens to generate
- temperature: Sampling temperature
- top_k: Top-k filtering
- top_p: Nucleus sampling threshold
- eos_token_id: Stop at this token
- use_cache: Use KV cache
- gpu_sampling: Use GPU-based sampling (avoids full logits D2H transfer)
-
- Returns:
- List of all token IDs (input + generated)
- """
- tokens = list(input_ids)
- past_key_values = None
-
- if use_cache:
- # Prefill
- hidden, past_key_values = self(tokens, use_cache=True)
- logits = self.get_logits(hidden)
-
- if gpu_sampling:
- # GPU sampling: only transfer 1 int instead of full vocab logits
- next_token = sample_token_gpu(logits[-1], temperature, top_k, top_p)
- else:
- last_logits = _to_float32_logits(logits.to_numpy()[-1])
- next_token = sample_token(last_logits, temperature, top_k, top_p)
- tokens.append(next_token)
-
- if eos_token_id is not None and next_token == eos_token_id:
- return tokens
-
- # Decode
- for _ in range(max_new_tokens - 1):
- hidden, past_key_values = self(
- [next_token], past_key_values=past_key_values, use_cache=True
- )
- logits = self.get_logits(hidden)
-
- if gpu_sampling:
- next_token = sample_token_gpu(logits[-1], temperature, top_k, top_p)
- else:
- last_logits = _to_float32_logits(logits.to_numpy()[-1])
- next_token = sample_token(last_logits, temperature, top_k, top_p)
- tokens.append(next_token)
-
- if eos_token_id is not None and next_token == eos_token_id:
- break
- else:
- for _ in range(max_new_tokens):
- hidden, _ = self(tokens, use_cache=False)
- logits = self.get_logits(hidden)
-
- if gpu_sampling:
- next_token = sample_token_gpu(logits[-1], temperature, top_k, top_p)
- else:
- last_logits = _to_float32_logits(logits.to_numpy()[-1])
- next_token = sample_token(last_logits, temperature, top_k, top_p)
- tokens.append(next_token)
-
- if eos_token_id is not None and next_token == eos_token_id:
- break
-
- return tokens
-
- def generate_stream(
- self,
- input_ids: list[int],
- max_new_tokens: int = 20,
- temperature: float = 1.0,
- top_k: int = 50,
- top_p: float = 0.9,
- eos_token_id: int | None = None,
- gpu_sampling: bool = False,
- ) -> Generator[int, None, None]:
- """Generate tokens autoregressively with streaming.
-
- Yields tokens one at a time as they are generated, enabling
- real-time text display in chat applications.
-
- Args:
- input_ids: Initial token IDs
- max_new_tokens: Maximum new tokens to generate
- temperature: Sampling temperature
- top_k: Top-k filtering
- top_p: Nucleus sampling threshold
- eos_token_id: Stop at this token
- gpu_sampling: Use GPU-based sampling (avoids full logits D2H transfer)
-
- Yields:
- Generated token IDs one at a time
-
- Example:
- >>> for token_id in model.generate_stream(input_ids, max_new_tokens=50):
- ... token_str = tokenizer.decode([token_id])
- ... print(token_str, end="", flush=True)
- """
- past_key_values = None
-
- # Prefill
- hidden, past_key_values = self(input_ids, use_cache=True)
- logits = self.get_logits(hidden)
-
- if gpu_sampling:
- next_token = sample_token_gpu(logits[-1], temperature, top_k, top_p)
- else:
- last_logits = _to_float32_logits(logits.to_numpy()[-1])
- next_token = sample_token(last_logits, temperature, top_k, top_p)
-
- yield next_token
-
- if eos_token_id is not None and next_token == eos_token_id:
- return
-
- # Decode
- for _ in range(max_new_tokens - 1):
- hidden, past_key_values = self(
- [next_token], past_key_values=past_key_values, use_cache=True
- )
- logits = self.get_logits(hidden)
-
- if gpu_sampling:
- next_token = sample_token_gpu(logits[-1], temperature, top_k, top_p)
- else:
- last_logits = _to_float32_logits(logits.to_numpy()[-1])
- next_token = sample_token(last_logits, temperature, top_k, top_p)
-
- yield next_token
-
- if eos_token_id is not None and next_token == eos_token_id:
- return
-
- def _decode_step_zero_alloc(
- self,
- token_id: int,
- position: int,
- context_len: int,
- buffers: DecodeBuffers,
- ) -> GPUArray:
- """Single decode step with zero memory allocations.
-
- Uses pre-allocated DecodeBuffers for all intermediate computations.
- All operations write to pre-allocated buffers, no new GPU memory is allocated.
-
- Args:
- token_id: Current token ID
- position: Position in sequence
- context_len: Total context length
- buffers: Pre-allocated decode buffers
-
- Returns:
- Hidden states [1, hidden_size]
- """
- # Get token embedding directly to hidden (no copy needed)
- embedding_lookup(self.embed_tokens, buffers.hidden, token_id)
-
- # Transformer blocks with fixed cache
- for block in self.blocks:
- # Pre-norm: hidden -> norm_out
- rmsnorm(
- buffers.hidden, block.attn_norm.weight, block.attn_norm.eps, out=buffers.norm_out
- )
-
- # Save residual
- copy_to(buffers.hidden, buffers.residual)
-
- # Attention with fixed cache (writes to buffers.hidden)
- self._attention_forward_zero_alloc(
- block.attn, buffers.norm_out, position, context_len, buffers
- )
-
- # Add residual: hidden = residual + hidden
- add_inplace(buffers.hidden, buffers.residual)
-
- # MLP pre-norm
- copy_to(buffers.hidden, buffers.residual)
- rmsnorm(buffers.hidden, block.mlp_norm.weight, block.mlp_norm.eps, out=buffers.norm_out)
-
- # MLP forward (SwiGLU)
- self._mlp_forward_zero_alloc(block.mlp, buffers.norm_out, buffers)
-
- # Add residual
- add_inplace(buffers.hidden, buffers.residual)
-
- # Final norm
- rmsnorm(buffers.hidden, self.final_norm.weight, self.final_norm.eps, out=buffers.norm_out)
- copy_to(buffers.norm_out, buffers.hidden)
-
- return buffers.hidden
-
- def _attention_forward_zero_alloc(
- self,
- attn: Attention,
- x: GPUArray,
- position: int,
- context_len: int,
- buffers: DecodeBuffers,
- use_position_ptr: bool = False,
- use_context_len_ptr: bool = False,
- max_kv_len: int | None = None,
- ) -> None:
- """Attention forward pass with zero allocations.
-
- Result is written to buffers.hidden.
-
- Args:
- use_position_ptr: If True, read position from buffers.position_buf
- (for CUDA Graph replay without recapture).
- use_context_len_ptr: If True, read context_len from buffers.context_len_buf
- (for CUDA Graph replay without recapture).
- max_kv_len: Maximum KV length for CUDA Graph shared memory allocation.
- Required if use_context_len_ptr=True.
- """
- # Fused QKV projection (1 matmul replaces 3, then zero-copy narrow views)
- # This is 4x faster for M=1 with cuBLASLt due to reduced kernel launch overhead
- attn.qkv_proj(x, out=buffers.qkv_proj_out)
-
- # Apply biases (fused projection has no bias)
- if attn.q_proj.bias is not None:
- bias_add_inplace(buffers.q_view, attn.q_proj.bias)
- if attn.k_proj.bias is not None:
- bias_add_inplace(buffers.k_view, attn.k_proj.bias)
- if attn.v_proj.bias is not None:
- bias_add_inplace(buffers.v_view, attn.v_proj.bias)
-
- # Reshape narrow views to 3D using pre-allocated buffers
- # q_view, k_view, v_view are pre-created zero-copy views of qkv_proj_out
- reshape_copy(buffers.q_view, (1, attn.num_heads, attn.head_dim), out=buffers.q)
- reshape_copy(buffers.k_view, (1, attn.num_kv_heads, attn.head_dim), out=buffers.k)
- reshape_copy(buffers.v_view, (1, attn.num_kv_heads, attn.head_dim), out=buffers.v)
- q, k, v = buffers.q, buffers.k, buffers.v
-
- # QK Norm (Qwen3) - zero allocation using pre-allocated buffers
- if attn.q_norm is not None and buffers.q_2d is not None and buffers.q_flat is not None:
- # Reshape q [1,H,D] -> q_flat [H,D], apply norm, reshape back to q [1,H,D]
- reshape_copy(q, (attn.num_heads, attn.head_dim), out=buffers.q_flat)
- rmsnorm(buffers.q_flat, attn.q_norm.weight, attn.q_norm.eps, out=buffers.q_2d)
- reshape_copy(buffers.q_2d, (1, attn.num_heads, attn.head_dim), out=buffers.q)
- q = buffers.q
- if attn.k_norm is not None and buffers.k_2d is not None and buffers.k_flat is not None:
- # Reshape k [1,H,D] -> k_flat [H,D], apply norm, reshape back to k [1,H,D]
- reshape_copy(k, (attn.num_kv_heads, attn.head_dim), out=buffers.k_flat)
- rmsnorm(buffers.k_flat, attn.k_norm.weight, attn.k_norm.eps, out=buffers.k_2d)
- reshape_copy(buffers.k_2d, (1, attn.num_kv_heads, attn.head_dim), out=buffers.k)
- k = buffers.k
-
- # Apply RoPE using pre-computed GPU tables (zero allocation)
- if self.config.use_rope and hasattr(self, "_rope_cos_gpu"):
- # Extract single row from pre-computed tables using GPU kernel
- if use_position_ptr and buffers.position_buf is not None:
- # Use _ptr variants for CUDA Graph replay
- embedding_lookup_ptr(self._rope_cos_gpu, buffers.cos, buffers.position_buf)
- embedding_lookup_ptr(self._rope_sin_gpu, buffers.sin, buffers.position_buf)
- else:
- embedding_lookup(self._rope_cos_gpu, buffers.cos, position)
- embedding_lookup(self._rope_sin_gpu, buffers.sin, position)
- # buffers.cos/sin are already [1, head_dim] - use directly
- rope_inplace(q, k, buffers.cos, buffers.sin)
-
- # Update KV cache at position (GQA-expanded, transposed)
- if use_position_ptr and buffers.position_buf is not None:
- # Use _ptr variants for CUDA Graph replay
- kv_cache_update_gqa_ptr(k, attn._k_cache, attn.num_heads, buffers.position_buf)
- kv_cache_update_gqa_ptr(v, attn._v_cache, attn.num_heads, buffers.position_buf)
- else:
- kv_cache_update_gqa(k, attn._k_cache, attn.num_heads, position)
- kv_cache_update_gqa(v, attn._v_cache, attn.num_heads, position)
-
- # Transpose Q for SDPA: [1, num_heads, head_dim] -> [num_heads, 1, head_dim]
- transpose_3d_021(q, out=buffers.q_t)
-
- # SDPA with fixed cache
- if use_context_len_ptr and buffers.context_len_buf is not None:
- # Use pointer-based SDPA for CUDA Graph replay
- assert max_kv_len is not None, "max_kv_len required for CUDA Graph mode"
- sdpa_causal_fixed_cache_ptr(
- buffers.q_t,
- attn._k_cache,
- attn._v_cache,
- buffers.attn_out,
- buffers.context_len_buf,
- max_kv_len,
- )
- else:
- sdpa_causal_fixed_cache(
- buffers.q_t, attn._k_cache, attn._v_cache, buffers.attn_out, context_len
- )
-
- # Transpose output: [num_heads, 1, head_dim] -> [1, num_heads, head_dim]
- transpose_3d_021(buffers.attn_out, out=buffers.q) # Reuse q buffer for transposed output
-
- # Reshape to 2D: [1, hidden_size] - reuse q_proj_out buffer
- reshape_copy(buffers.q, (1, attn.num_heads * attn.head_dim), out=buffers.q_proj_out)
-
- # Output projection directly to hidden (eliminates copy)
- attn.o_proj(buffers.q_proj_out, out=buffers.hidden)
-
- def _mlp_forward_zero_alloc(
- self,
- mlp: MLP,
- x: GPUArray,
- buffers: DecodeBuffers,
- ) -> None:
- """MLP forward pass with zero allocations (SwiGLU).
-
- Result is written to buffers.hidden.
- """
- if mlp.activation == "silu":
- # Non-fused SwiGLU (2 separate matmuls) - for debugging
- mlp.gate_proj(x, out=buffers.mlp_gate)
- silu(buffers.mlp_gate, out=buffers.mlp_gate)
-
- mlp.up_proj(x, out=buffers.mlp_up)
-
- mul_inplace(buffers.mlp_gate, buffers.mlp_up)
-
- mlp.down_proj(buffers.mlp_gate, out=buffers.hidden)
- else:
- # GELU path (GPT-2) - still has allocations, rarely used
- fc1_out = mlp.fc1(x)
- gelu_out = gelu(fc1_out)
- fc2_out = mlp.fc2(gelu_out)
- copy_to(fc2_out, buffers.hidden)
-
- def _mlp_forward_batch_zero_alloc(
- self,
- mlp: MLP,
- x: GPUArray,
- buffers: DecodeBuffers,
- out: GPUArray,
- ) -> None:
- """Batch MLP forward pass with zero allocations (SwiGLU).
-
- Uses fused gate_up projection for efficiency.
-
- Args:
- mlp: MLP module
- x: Input tensor [seq_len, hidden_size]
- buffers: Pre-allocated decode buffers
- out: Output buffer [seq_len, hidden_size] to write result
- """
- seq_len = x.shape[0]
-
- if mlp.activation == "silu":
- # Fused gate_up projection
- gate_up_out = buffers.gate_up_out_batch.slice_rows(seq_len)
- mlp.gate_up_proj(x, out=gate_up_out)
-
- # Split into gate and up using narrow
- intermediate_size = mlp.intermediate_size
- gate = gate_up_out.narrow(0, intermediate_size) # [seq_len, intermediate_size]
- up = gate_up_out.narrow(intermediate_size, intermediate_size)
-
- # SiLU in-place on gate
- silu(gate, out=gate)
-
- # Multiply gate * up in-place
- mul_inplace(gate, up)
-
- # Down projection to output buffer
- mlp.down_proj(gate, out=out)
- else:
- # GELU path - still has allocations (rarely used)
- fc1_out = mlp.fc1(x)
- gelu_out = gelu(fc1_out)
- mlp.fc2(gelu_out, out=out)
-
- def _prefill_with_buffers(
- self,
- input_ids: list[int],
- buffers: PrefillBuffers,
- use_cache: bool = True,
- ) -> tuple[GPUArray, list[tuple | None] | None]:
- """Prefill forward pass with reduced allocations using pre-allocated buffers.
-
- Uses PrefillBuffers for projection outputs, attention intermediates, and MLP
- to reduce memory allocations during prefill. Full zero-allocation requires
- kernel-level support for partial buffer operations.
-
- Args:
- input_ids: Token IDs [seq_len]
- buffers: Pre-allocated prefill buffers
- use_cache: Whether to return KV cache
-
- Returns:
- Tuple of (hidden_states, present_key_values)
- """
- seq_len = len(input_ids)
- assert seq_len <= buffers.max_seq_len, (
- f"seq_len {seq_len} > max_seq_len {buffers.max_seq_len}"
- )
-
- position_ids = list(range(seq_len))
-
- # Token embeddings - copy to pre-allocated buffer
- if not hasattr(self, "_embed_np_cache"):
- self._embed_np_cache = self.embed_tokens.to_numpy()
- hidden_np = self._embed_np_cache[input_ids]
-
- # Add position embeddings (GPT-2 style)
- if self.position_embed is not None:
- if not hasattr(self, "_pos_embed_np_cache"):
- self._pos_embed_np_cache = self.position_embed.to_numpy()
- hidden_np = hidden_np + self._pos_embed_np_cache[position_ids]
-
- # Copy to pre-allocated hidden buffer
- hidden = from_numpy(hidden_np.astype(self._embed_np_cache.dtype))
- copy_to(hidden, buffers.hidden)
-
- # Transformer blocks with buffer reuse
- present_key_values = []
- for block in self.blocks:
- # Process using buffers where possible
- hidden, present_kv = self._prefill_block_with_buffers(
- block, buffers.hidden, position_ids, buffers, use_cache
- )
- present_key_values.append(present_kv)
-
- # Final norm - reuse norm_out buffer
- rmsnorm(buffers.hidden, self.final_norm.weight, self.final_norm.eps, out=buffers.norm_out)
- copy_to(buffers.norm_out, buffers.hidden)
-
- if use_cache:
- return buffers.hidden, present_key_values
- return buffers.hidden, None
-
- def _prefill_block_with_buffers(
- self,
- block: TransformerBlock,
- hidden: GPUArray,
- position_ids: list[int],
- buffers: PrefillBuffers,
- use_cache: bool,
- ) -> tuple[GPUArray, tuple | None]:
- """Single transformer block forward with buffer reuse.
-
- Args:
- block: TransformerBlock to process
- hidden: Input hidden states [seq_len, hidden_size]
- position_ids: Position IDs for RoPE
- buffers: Pre-allocated prefill buffers
- use_cache: Whether to return KV cache
-
- Returns:
- Tuple of (output_hidden, present_kv)
- """
- # Attention block
- # Pre-norm -> norm_out
- rmsnorm(hidden, block.attn_norm.weight, block.attn_norm.eps, out=buffers.norm_out)
-
- # Save residual
- copy_to(hidden, buffers.residual)
-
- # Attention forward with buffers
- attn_out, present_kv = self._prefill_attention_with_buffers(
- block.attn, buffers.norm_out, position_ids, buffers, use_cache
- )
-
- # Residual connection: hidden = residual + attn_out
- add_inplace(attn_out, buffers.residual)
- copy_to(attn_out, buffers.hidden)
-
- # MLP block
- # Pre-norm
- copy_to(buffers.hidden, buffers.residual)
- rmsnorm(buffers.hidden, block.mlp_norm.weight, block.mlp_norm.eps, out=buffers.norm_out)
-
- # MLP forward with buffers
- self._prefill_mlp_with_buffers(block.mlp, buffers.norm_out, buffers)
-
- # Residual connection
- add_inplace(buffers.hidden, buffers.residual)
-
- return buffers.hidden, present_kv
-
- def _prefill_attention_with_buffers(
- self,
- attn: Attention,
- x: GPUArray,
- position_ids: list[int],
- buffers: PrefillBuffers,
- use_cache: bool,
- ) -> tuple[GPUArray, tuple | None]:
- """Attention forward pass with buffer reuse during prefill.
-
- Args:
- attn: Attention layer
- x: Input [seq_len, hidden_size]
- position_ids: Position IDs for RoPE
- buffers: Pre-allocated prefill buffers
- use_cache: Whether to return KV cache
-
- Returns:
- Tuple of (output, present_kv)
- """
- seq_len = x.shape[0]
-
- # Project Q, K, V using pre-allocated buffers
- attn.q_proj(x, out=buffers.q_proj_out)
- attn.k_proj(x, out=buffers.k_proj_out)
- attn.v_proj(x, out=buffers.v_proj_out)
-
- # Reshape to 3D
- reshape_copy(buffers.q_proj_out, out=buffers.q)
- reshape_copy(buffers.k_proj_out, out=buffers.k)
- reshape_copy(buffers.v_proj_out, out=buffers.v)
- q, k, v = buffers.q, buffers.k, buffers.v
-
- # QK Norm (Qwen3 style)
- if attn.q_norm is not None and buffers.q_2d is not None:
- q_2d = reshape_copy(q, (seq_len * attn.num_heads, attn.head_dim))
- q_2d = attn.q_norm(q_2d)
- q = reshape_copy(q_2d, (seq_len, attn.num_heads, attn.head_dim))
- if attn.k_norm is not None and buffers.k_2d is not None:
- k_2d = reshape_copy(k, (seq_len * attn.num_kv_heads, attn.head_dim))
- k_2d = attn.k_norm(k_2d)
- k = reshape_copy(k_2d, (seq_len, attn.num_kv_heads, attn.head_dim))
-
- # Apply RoPE
- if self.config.use_rope and attn._cos is not None and attn._sin is not None:
- # Use Attention's precomputed cos/sin tables
- q_dtype = q.dtype
- if q_dtype == "float16":
- cos = from_numpy(attn._cos[position_ids].astype(np.float16))
- sin = from_numpy(attn._sin[position_ids].astype(np.float16))
- elif q_dtype == "bfloat16":
- # Fall back to float32 computation for bfloat16
- cos = from_numpy(attn._cos[position_ids].astype(np.float32))
- sin = from_numpy(attn._sin[position_ids].astype(np.float32))
- else:
- # FP32 path
- cos = from_numpy(attn._cos[position_ids].astype(np.float32))
- sin = from_numpy(attn._sin[position_ids].astype(np.float32))
- # Apply RoPE in-place (FP32 and FP16 have native kernel support)
- if q_dtype in ("float32", "float16"):
- rope_inplace(q, k, cos, sin)
-
- # Store for KV cache - MUST copy since buffers.k/v are reused across layers
- if use_cache:
- # Create copies of K, V to avoid aliasing
- # (shared buffers get overwritten by later layers)
- k_copy = reshape_copy(k, k.shape)
- v_copy = reshape_copy(v, v.shape)
- present_kv = (k_copy, v_copy)
- else:
- present_kv = None
-
- # Expand for GQA
- if attn.num_kv_groups > 1:
- k_expanded = repeat_interleave_axis1(k, attn.num_kv_groups)
- v_expanded = repeat_interleave_axis1(v, attn.num_kv_groups)
- else:
- k_expanded = k
- v_expanded = v
-
- # Transpose for SDPA: [seq, heads, dim] -> [heads, seq, dim]
- transpose_3d_021(q, out=buffers.q_t)
- k_t = transpose_3d_021(k_expanded) # Can't use buffer due to GQA expansion
- v_t = transpose_3d_021(v_expanded)
-
- # SDPA with causal mask
- sdpa_causal(buffers.q_t, k_t, v_t, out=buffers.attn_out)
-
- # Transpose back and reshape
- transpose_3d_021(buffers.attn_out, out=buffers.attn_out_t)
- reshape_copy(buffers.attn_out_t, out=buffers.attn_out_2d)
-
- # Output projection
- attn.o_proj(buffers.attn_out_2d, out=buffers.o_proj_out)
-
- return buffers.o_proj_out, present_kv
-
- def _prefill_mlp_with_buffers(
- self,
- mlp: MLP,
- x: GPUArray,
- buffers: PrefillBuffers,
- ) -> None:
- """MLP forward pass with buffer reuse during prefill.
-
- Result is written to buffers.hidden.
-
- Args:
- mlp: MLP layer
- x: Input [seq_len, hidden_size]
- buffers: Pre-allocated prefill buffers
- """
- if mlp.activation == "silu":
- # SwiGLU: gate_proj -> SiLU -> * up_proj -> down_proj
- mlp.gate_proj(x, out=buffers.mlp_gate)
- silu(buffers.mlp_gate, out=buffers.mlp_gate)
-
- mlp.up_proj(x, out=buffers.mlp_up)
-
- # Element-wise multiply in-place
- mul_inplace(buffers.mlp_gate, buffers.mlp_up)
-
- # Down projection
- mlp.down_proj(buffers.mlp_gate, out=buffers.mlp_down)
- copy_to(buffers.mlp_down, buffers.hidden)
- else:
- # GELU path (GPT-2)
- fc1_out = mlp.fc1(x)
- gelu_out = gelu(fc1_out)
- fc2_out = mlp.fc2(gelu_out)
- copy_to(fc2_out, buffers.hidden)
-
- def _decode_step_fixed_cache(
- self,
- token_id: int,
- position: int,
- context_len: int,
- ) -> GPUArray:
- """Single decode step using fixed-length KV cache (legacy, with allocations).
-
- Args:
- token_id: Current token ID
- position: Position in sequence
- context_len: Total context length
-
- Returns:
- Hidden states [1, hidden_size]
- """
- # Get token embedding
- if not hasattr(self, "_embed_np_cache"):
- self._embed_np_cache = self.embed_tokens.to_numpy()
- hidden_np = self._embed_np_cache[token_id : token_id + 1]
- hidden = from_numpy(hidden_np.astype(self._embed_np_cache.dtype))
-
- # Transformer blocks with fixed cache
- for block in self.blocks:
- # Pre-norm
- residual = hidden
- hidden = block.attn_norm(hidden)
-
- # Attention with fixed cache
- hidden = block.attn.forward_fixed_cache(hidden, position, context_len)
- hidden = add(residual, hidden)
-
- # MLP
- residual = hidden
- hidden = block.mlp_norm(hidden)
- hidden = block.mlp(hidden)
- hidden = add(residual, hidden)
-
- # Final norm
- hidden = self.final_norm(hidden)
-
- return hidden
-
- def _decode_step_fixed_cache_batch(
- self,
- token_ids: list[int],
- start_position: int,
- context_len: int,
- ) -> GPUArray:
- """Batch decode step using fixed-length KV cache.
-
- Processes multiple tokens at once for speculative decoding verification.
-
- Args:
- token_ids: List of token IDs to decode [seq_len tokens]
- start_position: Starting position in sequence (first token's position)
- context_len: Total context length after adding this batch
- (should equal start_position + len(token_ids))
-
- Returns:
- Hidden states [seq_len, hidden_size]
- """
- # Dispatch to optimized single-token path for M=1
- if len(token_ids) == 1:
- return self._decode_step_fixed_cache(token_ids[0], start_position, context_len)
-
- # M > 1: Batch decode path
- # Get token embeddings for batch
- if not hasattr(self, "_embed_np_cache"):
- self._embed_np_cache = self.embed_tokens.to_numpy()
- hidden_np = self._embed_np_cache[token_ids] # [seq_len, hidden_size]
- hidden = from_numpy(hidden_np.astype(self._embed_np_cache.dtype))
-
- # Transformer blocks with fixed cache (batch)
- for block in self.blocks:
- # Pre-norm
- residual = hidden
- hidden = block.attn_norm(hidden)
-
- # Attention with fixed cache (batch)
- hidden = block.attn.forward_fixed_cache_batch(hidden, start_position, context_len)
- hidden = add(residual, hidden)
-
- # MLP
- residual = hidden
- hidden = block.mlp_norm(hidden)
- hidden = block.mlp(hidden)
- hidden = add(residual, hidden)
-
- # Final norm
- hidden = self.final_norm(hidden)
-
- return hidden
-
- def _decode_step_fixed_cache_batch_zero_alloc(
- self,
- token_ids: list[int],
- start_position: int,
- context_len: int,
- buffers: DecodeBuffers,
- ) -> GPUArray:
- """Batch decode step using pre-allocated buffers (zero-allocation).
-
- This function is designed to be CUDA Graph capture compatible.
- All intermediate buffers are pre-allocated in DecodeBuffers.
-
- Args:
- token_ids: List of token IDs to decode [seq_len tokens]
- start_position: Starting position in sequence (first token's position)
- context_len: Total context length after adding this batch
- buffers: Pre-allocated batch decode buffers
-
- Returns:
- Hidden states [seq_len, hidden_size] (view into buffers.hidden_batch)
-
- Note:
- Requires buffers.max_batch_size > 0 and len(token_ids) <= max_batch_size.
- TODO: CUDA Graph capture can be added once this path is validated.
- """
- seq_len = len(token_ids)
-
- if buffers.max_batch_size == 0:
- raise RuntimeError(
- "Batch buffers not allocated. Call DecodeBuffers.allocate(..., max_batch_size=8)"
- )
- if seq_len > buffers.max_batch_size:
- raise ValueError(
- f"seq_len ({seq_len}) exceeds max_batch_size ({buffers.max_batch_size})"
- )
-
- # Get embeddings (still uses numpy - small one-time cost)
- if not hasattr(self, "_embed_np_cache"):
- self._embed_np_cache = self.embed_tokens.to_numpy()
- hidden_np = self._embed_np_cache[token_ids] # [seq_len, hidden_size]
-
- # Copy to batch hidden buffer
- assert buffers.hidden_batch is not None
- buffers.hidden_batch._get_native().copy_from_numpy(
- hidden_np.astype(self._embed_np_cache.dtype)
- )
-
- # Use slice_rows for actual seq_len (logical batch size)
- # slice_rows creates a zero-copy view of the first N rows
- hidden = buffers.hidden_batch.slice_rows(seq_len)
- residual_buf = (
- buffers.residual_batch.slice_rows(seq_len) if buffers.residual_batch else None
- )
- norm_out_buf = (
- buffers.norm_out_batch.slice_rows(seq_len) if buffers.norm_out_batch else None
- )
-
- # Transformer blocks
- for block in self.blocks:
- # Pre-norm: attn_norm(hidden) -> norm_out
- if norm_out_buf is not None:
- rmsnorm(hidden, block.attn_norm.weight, block.attn_norm.eps, out=norm_out_buf)
- else:
- norm_out_buf = block.attn_norm(hidden)
-
- # Save residual
- if residual_buf is not None:
- copy_to(hidden, residual_buf)
- else:
- residual_buf = hidden
-
- # Attention with fixed cache (batch) - uses existing path for now
- # TODO: Add forward_fixed_cache_batch_zero_alloc to Attention class
- attn_out = block.attn.forward_fixed_cache_batch(
- norm_out_buf, start_position, context_len
- )
-
- # Residual connection: hidden = residual + attn_out
- add_inplace(residual_buf, attn_out)
- hidden = residual_buf
-
- # MLP norm
- if norm_out_buf is not None:
- rmsnorm(hidden, block.mlp_norm.weight, block.mlp_norm.eps, out=norm_out_buf)
- else:
- norm_out_buf = block.mlp_norm(hidden)
-
- # Save residual for MLP
- if residual_buf is not hidden:
- copy_to(hidden, residual_buf)
-
- # MLP - uses existing path for now
- # TODO: Add zero-alloc MLP path
- mlp_out = block.mlp(norm_out_buf)
-
- # Residual connection
- add_inplace(residual_buf, mlp_out)
- hidden = residual_buf
-
- # Final norm
- if norm_out_buf is not None:
- rmsnorm(hidden, self.final_norm.weight, self.final_norm.eps, out=norm_out_buf)
- return norm_out_buf
- else:
- return self.final_norm(hidden)
-
- # =========================================================================
- # Self-Speculative Decoding
- # =========================================================================
-
- def snapshot_kv_cache(self) -> list[tuple[np.ndarray, np.ndarray]]:
- """Snapshot all layer KV caches to CPU memory.
-
- Returns:
- List of (k_cache_np, v_cache_np) tuples, one per layer.
- Each cache is numpy array of shape [num_heads, max_seq_len, head_dim].
- """
- snapshot = []
- for block in self.blocks:
- k_np = block.attn._k_cache.to_numpy().copy()
- v_np = block.attn._v_cache.to_numpy().copy()
- snapshot.append((k_np, v_np))
- return snapshot
-
- def restore_kv_cache(self, snapshot: list[tuple[np.ndarray, np.ndarray]]) -> None:
- """Restore all layer KV caches from CPU snapshot.
-
- Args:
- snapshot: List of (k_cache_np, v_cache_np) tuples from snapshot_kv_cache().
-
- Note:
- This method copies data into existing arrays rather than replacing them.
- This is critical for CUDA Graph compatibility - the graph captures pointer
- addresses, so we must preserve the existing arrays.
- """
- for i, block in enumerate(self.blocks):
- k_np, v_np = snapshot[i]
- # Copy data into existing arrays (preserves pointers for CUDA Graph)
- k_np_typed: np.ndarray = k_np.astype(np.float16)
- v_np_typed: np.ndarray = v_np.astype(np.float16)
- block.attn._k_cache._get_native().copy_from_numpy(k_np_typed)
- block.attn._v_cache._get_native().copy_from_numpy(v_np_typed)
-
- def _draft_forward_early_layers(
- self,
- token_id: int,
- position: int,
- context_len: int,
- num_draft_layers: int,
- ) -> GPUArray:
- """Forward pass through only the first N layers (draft model).
-
- Uses the same KV cache as the full model but only updates early layers.
- After draft is done, the early layer KV entries need to be restored
- before running the full model verification.
-
- Args:
- token_id: Current token ID
- position: Position in sequence
- context_len: Total context length
- num_draft_layers: Number of early layers to use as draft
-
- Returns:
- Hidden states [1, hidden_size] after num_draft_layers
- """
- # Get token embedding
- if not hasattr(self, "_embed_np_cache"):
- self._embed_np_cache = self.embed_tokens.to_numpy()
- hidden_np = self._embed_np_cache[token_id : token_id + 1]
- hidden = from_numpy(hidden_np.astype(self._embed_np_cache.dtype))
-
- # Only run through first num_draft_layers blocks
- for i in range(min(num_draft_layers, len(self.blocks))):
- block = self.blocks[i]
- # Pre-norm
- residual = hidden
- hidden = block.attn_norm(hidden)
-
- # Attention with fixed cache
- hidden = block.attn.forward_fixed_cache(hidden, position, context_len)
- hidden = add(residual, hidden)
-
- # MLP
- residual = hidden
- hidden = block.mlp_norm(hidden)
- hidden = block.mlp(hidden)
- hidden = add(residual, hidden)
-
- # Note: We do NOT apply final_norm here since draft output
- # is only used for sampling, not for precise logits
- return hidden
-
- def _draft_get_logits(self, hidden: GPUArray) -> GPUArray:
- """Get logits from draft hidden states (after early layers).
-
- This applies final_norm and then computes logits.
- Note: The draft hidden states are from early layers, so the logits
- may not be identical to full model logits.
- """
- # Apply final norm (needed for proper logits computation)
- hidden_normed = self.final_norm(hidden)
- return self.get_logits(hidden_normed)
-
- def decode_step_self_speculative_lookahead(
- self,
- token_id: int,
- max_draft_tokens: int = 4,
- draft_layers: int = 8,
- ) -> tuple[list[int], dict]:
- """Self-speculative decode step with GPU-side lookahead KV (no CPU copies).
-
- Uses lookahead KV cache management to avoid CPU-GPU transfers.
-
- IMPORTANT: Before calling this method:
- 1. Run prefill and store KV using kv_cache_prefill_gqa()
- 2. Call set_lookahead_confirmed_pos(prefill_len) to mark prefill KV as committed
-
- Algorithm:
- 1. Generate draft tokens using early layers (writes to speculative positions)
- 2. Reset lookahead, verify with full model in batch
- 3. Accept tokens until first disagreement
- 4. Re-run for accepted tokens to ensure correct KV
- 5. Commit accepted tokens
-
- Args:
- token_id: Current token ID (the last accepted token)
- max_draft_tokens: Maximum number of draft tokens to generate
- draft_layers: Number of early layers to use as draft
-
- Returns:
- Tuple of:
- - accepted_tokens: List of accepted token IDs
- - stats: Dict with 'draft_count', 'accepted_count' for analysis
- """
- confirmed_pos = self.get_lookahead_confirmed_pos()
-
- # === Step 1: Generate draft tokens using early layers ===
- # Reset lookahead before draft phase
- self.reset_lookahead_all()
-
- draft_tokens = []
- current_token = token_id
-
- for i in range(max_draft_tokens):
- pos = confirmed_pos + i
- ctx = confirmed_pos + i + 1
- # Forward through early layers only
- hidden = self._draft_forward_early_layers(current_token, pos, ctx, draft_layers)
- logits = self._draft_get_logits(hidden)
- logits_np = logits.to_numpy()[-1]
- next_token = int(np.argmax(logits_np))
-
- draft_tokens.append(next_token)
- current_token = next_token
-
- # === Step 2: Reset and verify with full model in batch ===
- self.reset_lookahead_all()
-
- verify_input = [token_id] + draft_tokens[:-1]
- verify_ctx = confirmed_pos + len(verify_input)
-
- hidden_batch = self._decode_step_fixed_cache_batch(verify_input, confirmed_pos, verify_ctx)
- verify_logits = self.get_logits(hidden_batch)
- verify_logits_np = verify_logits.to_numpy()
-
- # === Step 3: Accept/Reject tokens ===
- accepted_tokens = []
- for i, draft_token in enumerate(draft_tokens):
- target_token = int(np.argmax(verify_logits_np[i]))
-
- if target_token == draft_token:
- accepted_tokens.append(draft_token)
- else:
- accepted_tokens.append(target_token)
- break
-
- # === Step 4: Re-run for accepted tokens if partial accept ===
- if len(accepted_tokens) < max_draft_tokens:
- self.reset_lookahead_all()
- # Use CUDA Graph if available
- use_graph = hasattr(self, "_decode_graph_ready") and self._decode_graph_ready
- current = token_id
- for i, acc_token in enumerate(accepted_tokens):
- pos = confirmed_pos + i
- ctx = confirmed_pos + i + 1
- if use_graph:
- self._decode_step_graph_replay(current, pos, ctx)
- else:
- self._decode_step_fixed_cache(current, pos, ctx)
- current = acc_token
-
- # === Step 5: Commit accepted tokens ===
- self.commit_lookahead_all(len(accepted_tokens))
-
- stats = {
- "draft_count": len(draft_tokens),
- "accepted_count": len(
- [
- t
- for i, t in enumerate(accepted_tokens)
- if i < len(draft_tokens) and t == draft_tokens[i]
- ]
- ),
- }
-
- return accepted_tokens, stats
-
- # =========================================================================
- # Lookahead KV Cache Management (GPU-side, no CPU copies)
- # =========================================================================
-
- def set_lookahead_confirmed_pos(self, pos: int) -> None:
- """Set confirmed position for all layers (e.g., after prefill).
-
- Args:
- pos: Position where KV is finalized (tokens 0 to pos-1 are committed).
- """
- for block in self.blocks:
- block.attn.set_confirmed_pos(pos)
-
- def reset_lookahead_all(self) -> None:
- """Reset lookahead pointer to confirmed position for all layers.
-
- Called at the start of each Jacobi iteration. This resets the write
- pointer without modifying KV cache - speculative positions will be
- overwritten by the next forward pass.
- """
- for block in self.blocks:
- block.attn.reset_lookahead()
-
- def commit_lookahead_all(self, n_accepted: int) -> None:
- """Commit accepted tokens for all layers.
-
- Args:
- n_accepted: Number of accepted tokens to commit.
- """
- for block in self.blocks:
- block.attn.commit_lookahead(n_accepted)
-
- def get_lookahead_confirmed_pos(self) -> int:
- """Get current confirmed position (from first layer)."""
- return self.blocks[0].attn.get_confirmed_pos()
-
- # =========================================================================
- # Jacobi Decoding
- # =========================================================================
-
- def _init_jacobi_guess(
- self,
- last_token: int,
- position: int,
- context_len: int,
- n_tokens: int,
- strategy: Literal["repeat", "ngram", "greedy"],
- ) -> list[int]:
- """Initialize guess tokens for Jacobi decoding.
-
- Args:
- last_token: The last accepted token
- position: Current position in sequence
- context_len: Current context length
- n_tokens: Number of tokens to guess
- strategy: Initialization strategy
- - "repeat": Repeat last_token n times
- - "ngram": Use n-gram cache (falls back to repeat if no match)
- - "greedy": Run greedy decode to get initial guess
-
- Returns:
- List of n_tokens guessed token IDs
- """
- if strategy == "repeat":
- return [last_token] * n_tokens
-
- elif strategy == "ngram":
- # N-gram cache lookup (simple implementation)
- # Check if we have this token in recent history
- if hasattr(self, "_ngram_cache") and last_token in self._ngram_cache:
- cached = self._ngram_cache[last_token]
- if len(cached) >= n_tokens:
- return cached[:n_tokens]
- # Fallback to repeat
- return [last_token] * n_tokens
-
- elif strategy == "greedy":
- # Run greedy sequential decode to get initial guess
- # This is expensive but gives best initial guess
- kv_snapshot = self.snapshot_kv_cache()
- guess = []
- pos = position
- ctx = context_len
- current = last_token
-
- for _ in range(n_tokens):
- hidden = self._decode_step_fixed_cache(current, pos, ctx)
- logits = self.get_logits(hidden)
- next_token = int(np.argmax(logits.to_numpy()[-1]))
- guess.append(next_token)
- current = next_token
- pos += 1
- ctx += 1
-
- # Restore KV cache
- self.restore_kv_cache(kv_snapshot)
- return guess
-
- else:
- raise ValueError(f"Unknown init strategy: {strategy}")
-
- # =========================================================================
- # Jacobi Decoding with Lookahead KV (GPU-side, no CPU copies)
- # =========================================================================
-
- def _init_jacobi_guess_lookahead(
- self,
- last_token: int,
- n_tokens: int,
- strategy: Literal["repeat", "ngram", "greedy"],
- ) -> list[int]:
- """Initialize guess tokens for Jacobi lookahead (no CPU copies).
-
- Args:
- last_token: The last accepted token
- n_tokens: Number of tokens to guess
- strategy: Initialization strategy
- - "repeat": Repeat last_token n times
- - "ngram": Use n-gram cache (falls back to repeat)
- - "greedy": Run greedy decode (writes to lookahead positions)
-
- Returns:
- List of n_tokens guessed token IDs
- """
- if strategy == "repeat":
- return [last_token] * n_tokens
-
- elif strategy == "ngram":
- if hasattr(self, "_ngram_cache") and last_token in self._ngram_cache:
- cached = self._ngram_cache[last_token]
- if len(cached) >= n_tokens:
- return cached[:n_tokens]
- return [last_token] * n_tokens
-
- elif strategy == "greedy":
- # Run greedy decode using lookahead positions
- # This writes KV at [confirmed_pos, confirmed_pos + n_tokens)
- confirmed_pos = self.get_lookahead_confirmed_pos()
- guess = []
- current = last_token
-
- for i in range(n_tokens):
- pos = confirmed_pos + i
- ctx = confirmed_pos + i + 1
- hidden = self._decode_step_fixed_cache(current, pos, ctx)
- logits = self.get_logits(hidden)
- next_token = int(np.argmax(logits.to_numpy()[-1]))
- guess.append(next_token)
- current = next_token
-
- # Reset lookahead after greedy init (KV will be overwritten)
- self.reset_lookahead_all()
- return guess
-
- else:
- raise ValueError(f"Unknown init strategy: {strategy}")
-
- def decode_step_jacobi_lookahead(
- self,
- token_id: int,
- n_tokens: int = 8,
- max_iter: int = 3,
- init_strategy: Literal["repeat", "ngram", "greedy"] = "repeat",
- ) -> tuple[list[int], dict]:
- """Jacobi decoding step with GPU-side lookahead KV (no CPU copies).
-
- This method uses the lookahead KV cache management to avoid all
- CPU-GPU memory transfers during Jacobi iterations.
-
- IMPORTANT: Before calling this method:
- 1. Run prefill and store KV using kv_cache_prefill_gqa()
- 2. Call set_lookahead_confirmed_pos(prefill_len) to mark prefill KV as committed
-
- Algorithm:
- 1. Initialize N future positions with a guess
- 2. Reset lookahead pointer (no KV modification)
- 3. Batch forward - writes KV at [confirmed_pos, confirmed_pos + n_tokens)
- 4. Update guess with argmax(logits)
- 5. Repeat until convergence or max_iter
- 6. Commit accepted tokens by advancing confirmed_pos
-
- Args:
- token_id: Current token ID (the last accepted token)
- n_tokens: Number of tokens to decode in parallel (default: 8)
- max_iter: Maximum iterations for convergence (default: 3)
- init_strategy: How to initialize guess tokens
- - "repeat": Repeat last token (fast, simple)
- - "ngram": Use n-gram cache if available
- - "greedy": Run greedy decode first (slow but accurate)
-
- Returns:
- Tuple of:
- - accepted_tokens: List of accepted token IDs
- - stats: Dict with 'iterations', 'converged', 'accepted_count'
- """
- # Get confirmed position (this is our starting point)
- confirmed_pos = self.get_lookahead_confirmed_pos()
-
- # Initialize guess (may use lookahead positions for greedy)
- guess = self._init_jacobi_guess_lookahead(token_id, n_tokens, init_strategy)
-
- iterations_used = 0
- converged = False
- prev_guess = None
-
- for iteration in range(max_iter):
- iterations_used = iteration + 1
-
- # Reset lookahead pointer (does NOT modify KV cache)
- self.reset_lookahead_all()
-
- # Batch forward: input [last_token, guess[0], ..., guess[n-2]]
- # produces logits for [guess[0], guess[1], ..., guess[n-1]]
- # Writes KV at [confirmed_pos, confirmed_pos + n_tokens)
- input_tokens = [token_id] + guess[:-1]
- start_pos = confirmed_pos
- ctx_len = confirmed_pos + len(input_tokens)
-
- hidden = self._decode_step_fixed_cache_batch(input_tokens, start_pos, ctx_len)
- logits = self.get_logits(hidden)
- logits_np = logits.to_numpy() # [n_tokens, vocab_size]
-
- # Update guess with argmax
- new_guess = [int(np.argmax(logits_np[i])) for i in range(n_tokens)]
-
- # Check full convergence
- if new_guess == guess:
- converged = True
- break
-
- prev_guess = guess
- guess = new_guess
-
- # Find longest converged prefix
- if converged:
- accepted_tokens = guess
- else:
- accepted_tokens = []
- if prev_guess is not None:
- for i in range(n_tokens):
- if guess[i] == prev_guess[i]:
- accepted_tokens.append(guess[i])
- else:
- break
- if len(accepted_tokens) == 0:
- accepted_tokens = [guess[0]]
-
- # Commit accepted tokens - this is the ONLY state change
- # The KV for accepted tokens is already written from the last iteration
- # We just need to run one more forward to ensure KV is correct
- self.reset_lookahead_all()
-
- # Re-run with just the accepted tokens to ensure KV is correct
- if len(accepted_tokens) < n_tokens:
- # KV may have extra speculative entries - need to overwrite with correct values
- # Run sequential for accepted tokens only
- # Use CUDA Graph if available
- use_graph = hasattr(self, "_decode_graph_ready") and self._decode_graph_ready
- current = token_id
- for i, acc_token in enumerate(accepted_tokens):
- pos = confirmed_pos + i
- ctx = confirmed_pos + i + 1
- if use_graph:
- self._decode_step_graph_replay(current, pos, ctx)
- else:
- self._decode_step_fixed_cache(current, pos, ctx)
- current = acc_token
- # If all converged, KV is already correct from last batch forward
-
- # Commit the accepted tokens
- self.commit_lookahead_all(len(accepted_tokens))
-
- # Update n-gram cache for future use
- if not hasattr(self, "_ngram_cache"):
- self._ngram_cache: dict[int, list[int]] = {}
- self._ngram_cache[token_id] = accepted_tokens.copy()
-
- stats = {
- "iterations": iterations_used,
- "converged": converged,
- "accepted_count": len(accepted_tokens),
- }
-
- return accepted_tokens, stats
-
-
-# =============================================================================
-# Type Aliases
-# =============================================================================
-
-# GPT2Model and LlamaModel are now simple aliases for CausalTransformerModel.
-# All models use CausalTransformerModel as the single runtime type.
-GPT2Model = CausalTransformerModel
-LlamaModel = CausalTransformerModel
-# Legacy component aliases (import from layers module)
-RMSNorm = Norm # Use Norm with norm_type="rmsnorm"
-LayerNorm = Norm # Use Norm with norm_type="layernorm"
-LlamaAttention = Attention
-LlamaMLP = MLP
-LlamaBlock = TransformerBlock
-CausalSelfAttention = Attention
+__all__ = [
+ # Primary model class
+ "CausalTransformerModel",
+ # Architecture aliases
+ "GPT2Model",
+ "LlamaModel",
+ # Legacy aliases
+ "RMSNorm",
+ "LayerNorm",
+ "LlamaAttention",
+ "LlamaMLP",
+ "LlamaBlock",
+ "CausalSelfAttention",
+]
diff --git a/src/pygpukit/llm/models/__init__.py b/src/pygpukit/llm/models/__init__.py
new file mode 100644
index 0000000..b18cf77
--- /dev/null
+++ b/src/pygpukit/llm/models/__init__.py
@@ -0,0 +1,34 @@
+"""LLM model implementations.
+
+This module provides unified transformer runtime implementations.
+"""
+
+from __future__ import annotations
+
+# Legacy component aliases (for backward compatibility)
+from pygpukit.llm.models.causal import (
+ CausalSelfAttention,
+ CausalTransformerModel,
+ GPT2Model,
+ LayerNorm,
+ LlamaAttention,
+ LlamaBlock,
+ LlamaMLP,
+ LlamaModel,
+ RMSNorm,
+)
+
+__all__ = [
+ # Primary model class
+ "CausalTransformerModel",
+ # Architecture aliases
+ "GPT2Model",
+ "LlamaModel",
+ # Legacy aliases
+ "RMSNorm",
+ "LayerNorm",
+ "LlamaAttention",
+ "LlamaMLP",
+ "LlamaBlock",
+ "CausalSelfAttention",
+]
diff --git a/src/pygpukit/llm/models/causal.py b/src/pygpukit/llm/models/causal.py
new file mode 100644
index 0000000..4b35245
--- /dev/null
+++ b/src/pygpukit/llm/models/causal.py
@@ -0,0 +1,1501 @@
+"""CausalTransformerModel implementation for PyGPUkit.
+
+Provides the unified Transformer runtime for GPT-2, LLaMA, and Qwen3 architectures.
+Model-specific behavior is controlled by the ModelSpec configuration.
+
+Key features:
+- Hybrid Attention: CPU for seq_len=1 (decode), GPU for prefill
+- GPU-native operations: RMSNorm, LayerNorm, SDPA, SiLU, GELU, RoPE
+- CUDA Graph support for zero-allocation decode
+- Speculative and Jacobi decoding modes
+"""
+
+from __future__ import annotations
+
+from collections.abc import Generator
+from typing import TYPE_CHECKING, Literal
+
+import numpy as np
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.factory import from_numpy
+
+# Import from refactored modules
+from pygpukit.llm.buffers import DecodeBuffers, PrefillBuffers
+from pygpukit.llm.config import ModelSpec, TransformerConfig
+from pygpukit.llm.layers import (
+ MLP,
+ Attention,
+ Norm,
+ TransformerBlock,
+)
+from pygpukit.llm.sampling import sample_token
+from pygpukit.ops.basic import (
+ add,
+ add_inplace,
+ bias_add_inplace,
+ copy_to,
+ embedding_lookup,
+ embedding_lookup_ptr,
+ gelu,
+ kv_cache_update_gqa,
+ kv_cache_update_gqa_ptr,
+ matmul,
+ mul_inplace,
+ repeat_interleave_axis1,
+ reshape_copy,
+ rmsnorm,
+ rope_inplace,
+ sample_token_gpu,
+ sdpa_causal,
+ sdpa_causal_fixed_cache,
+ sdpa_causal_fixed_cache_ptr,
+ silu,
+ transpose,
+ transpose_3d_021,
+)
+
+if TYPE_CHECKING:
+ pass
+
+
+def _to_float32_logits(logits_np: np.ndarray) -> np.ndarray:
+ """Convert logits to float32 for sampling.
+
+ If logits are stored as uint16 (bfloat16 representation), convert them
+ to float32. Otherwise return as-is.
+ """
+ if logits_np.dtype == np.uint16:
+ # bfloat16 stored as uint16: convert to float32
+ return (logits_np.astype(np.uint32) << 16).view(np.float32)
+ return logits_np.astype(np.float32)
+
+
+# =============================================================================
+# Unified CausalTransformerModel
+# =============================================================================
+
+
+class CausalTransformerModel:
+ """Unified causal transformer model.
+
+ The single runtime model for all architectures (GPT-2, LLaMA, Qwen3).
+ Model-specific behavior is controlled by the spec attribute.
+ """
+
+ # Type hints for dynamically added attributes
+ _batch_decode_buffers: DecodeBuffers | None
+ _batch_token_ids_np: np.ndarray
+
+ def __init__(
+ self,
+ config: TransformerConfig,
+ embed_tokens: GPUArray,
+ blocks: list[TransformerBlock],
+ final_norm: Norm,
+ lm_head: GPUArray | None = None,
+ position_embed: GPUArray | None = None, # For GPT-2 style
+ spec: ModelSpec | None = None,
+ ):
+ self.config = config
+ self.embed_tokens = embed_tokens
+ self.blocks = blocks
+ self.final_norm = final_norm
+ self._lm_head = lm_head
+ self.position_embed = position_embed
+ self.spec = spec
+
+ def __call__(
+ self,
+ input_ids: list[int],
+ position_ids: list[int] | None = None,
+ past_key_values: list[tuple | None] | None = None,
+ use_cache: bool = False,
+ ) -> tuple[GPUArray, list[tuple | None] | None]:
+ """Forward pass.
+
+ Args:
+ input_ids: Token IDs [seq_len]
+ position_ids: Position IDs (auto-generated if None)
+ past_key_values: List of (k, v) tuples per layer
+ use_cache: Whether to return KV cache
+
+ Returns:
+ Tuple of (hidden_states, present_key_values)
+ """
+ seq_len = len(input_ids)
+
+ if position_ids is None:
+ if past_key_values is not None and past_key_values[0] is not None:
+ past_len = past_key_values[0][0].shape[0]
+ position_ids = list(range(past_len, past_len + seq_len))
+ else:
+ position_ids = list(range(seq_len))
+
+ # Token embeddings (cache numpy array to avoid repeated GPU->CPU transfer)
+ if not hasattr(self, "_embed_np_cache"):
+ self._embed_np_cache = self.embed_tokens.to_numpy()
+ hidden_np = self._embed_np_cache[input_ids]
+
+ # Add position embeddings (GPT-2 style)
+ if self.position_embed is not None:
+ if not hasattr(self, "_pos_embed_np_cache"):
+ self._pos_embed_np_cache = self.position_embed.to_numpy()
+ hidden_np = hidden_np + self._pos_embed_np_cache[position_ids]
+
+ hidden: GPUArray = from_numpy(hidden_np.astype(self._embed_np_cache.dtype))
+
+ # Transformer blocks
+ present_key_values = []
+ for i, block in enumerate(self.blocks):
+ past_kv = past_key_values[i] if past_key_values else None
+ hidden, present_kv = block(hidden, position_ids, past_kv, use_cache)
+ present_key_values.append(present_kv)
+
+ # Final norm
+ hidden = self.final_norm(hidden)
+
+ if use_cache:
+ return hidden, present_key_values
+ return hidden, None
+
+ @property
+ def lm_head(self) -> GPUArray | None:
+ """LM head weights (for backward compatibility)."""
+ return self._lm_head
+
+ def get_logits(self, hidden: GPUArray) -> GPUArray:
+ """Compute logits from hidden states on GPU."""
+ # Cache transposed lm_head to avoid repeated transpose
+ if not hasattr(self, "_lm_head_t_cache"):
+ lm_head = self._lm_head if self._lm_head is not None else self.embed_tokens
+ self._lm_head_t_cache = transpose(lm_head)
+
+ # GPU matmul: hidden @ lm_head.T
+ # hidden: [seq_len, hidden_size], lm_head: [vocab_size, hidden_size]
+ # Result: [seq_len, vocab_size]
+ return matmul(hidden, self._lm_head_t_cache)
+
+ def generate(
+ self,
+ input_ids: list[int],
+ max_new_tokens: int = 20,
+ temperature: float = 1.0,
+ top_k: int = 50,
+ top_p: float = 0.9,
+ eos_token_id: int | None = None,
+ use_cache: bool = True,
+ gpu_sampling: bool = False,
+ ) -> list[int]:
+ """Generate tokens autoregressively.
+
+ Args:
+ input_ids: Initial token IDs
+ max_new_tokens: Maximum new tokens to generate
+ temperature: Sampling temperature
+ top_k: Top-k filtering
+ top_p: Nucleus sampling threshold
+ eos_token_id: Stop at this token
+ use_cache: Use KV cache
+ gpu_sampling: Use GPU-based sampling (avoids full logits D2H transfer)
+
+ Returns:
+ List of all token IDs (input + generated)
+ """
+ tokens = list(input_ids)
+ past_key_values = None
+
+ if use_cache:
+ # Prefill
+ hidden, past_key_values = self(tokens, use_cache=True)
+ logits = self.get_logits(hidden)
+
+ if gpu_sampling:
+ # GPU sampling: only transfer 1 int instead of full vocab logits
+ next_token = sample_token_gpu(logits[-1], temperature, top_k, top_p)
+ else:
+ last_logits = _to_float32_logits(logits.to_numpy()[-1])
+ next_token = sample_token(last_logits, temperature, top_k, top_p)
+ tokens.append(next_token)
+
+ if eos_token_id is not None and next_token == eos_token_id:
+ return tokens
+
+ # Decode
+ for _ in range(max_new_tokens - 1):
+ hidden, past_key_values = self(
+ [next_token], past_key_values=past_key_values, use_cache=True
+ )
+ logits = self.get_logits(hidden)
+
+ if gpu_sampling:
+ next_token = sample_token_gpu(logits[-1], temperature, top_k, top_p)
+ else:
+ last_logits = _to_float32_logits(logits.to_numpy()[-1])
+ next_token = sample_token(last_logits, temperature, top_k, top_p)
+ tokens.append(next_token)
+
+ if eos_token_id is not None and next_token == eos_token_id:
+ break
+ else:
+ for _ in range(max_new_tokens):
+ hidden, _ = self(tokens, use_cache=False)
+ logits = self.get_logits(hidden)
+
+ if gpu_sampling:
+ next_token = sample_token_gpu(logits[-1], temperature, top_k, top_p)
+ else:
+ last_logits = _to_float32_logits(logits.to_numpy()[-1])
+ next_token = sample_token(last_logits, temperature, top_k, top_p)
+ tokens.append(next_token)
+
+ if eos_token_id is not None and next_token == eos_token_id:
+ break
+
+ return tokens
+
+ def generate_stream(
+ self,
+ input_ids: list[int],
+ max_new_tokens: int = 20,
+ temperature: float = 1.0,
+ top_k: int = 50,
+ top_p: float = 0.9,
+ eos_token_id: int | None = None,
+ gpu_sampling: bool = False,
+ ) -> Generator[int, None, None]:
+ """Generate tokens autoregressively with streaming.
+
+ Yields tokens one at a time as they are generated, enabling
+ real-time text display in chat applications.
+
+ Args:
+ input_ids: Initial token IDs
+ max_new_tokens: Maximum new tokens to generate
+ temperature: Sampling temperature
+ top_k: Top-k filtering
+ top_p: Nucleus sampling threshold
+ eos_token_id: Stop at this token
+ gpu_sampling: Use GPU-based sampling (avoids full logits D2H transfer)
+
+ Yields:
+ Generated token IDs one at a time
+
+ Example:
+ >>> for token_id in model.generate_stream(input_ids, max_new_tokens=50):
+ ... token_str = tokenizer.decode([token_id])
+ ... print(token_str, end="", flush=True)
+ """
+ past_key_values = None
+
+ # Prefill
+ hidden, past_key_values = self(input_ids, use_cache=True)
+ logits = self.get_logits(hidden)
+
+ if gpu_sampling:
+ next_token = sample_token_gpu(logits[-1], temperature, top_k, top_p)
+ else:
+ last_logits = _to_float32_logits(logits.to_numpy()[-1])
+ next_token = sample_token(last_logits, temperature, top_k, top_p)
+
+ yield next_token
+
+ if eos_token_id is not None and next_token == eos_token_id:
+ return
+
+ # Decode
+ for _ in range(max_new_tokens - 1):
+ hidden, past_key_values = self(
+ [next_token], past_key_values=past_key_values, use_cache=True
+ )
+ logits = self.get_logits(hidden)
+
+ if gpu_sampling:
+ next_token = sample_token_gpu(logits[-1], temperature, top_k, top_p)
+ else:
+ last_logits = _to_float32_logits(logits.to_numpy()[-1])
+ next_token = sample_token(last_logits, temperature, top_k, top_p)
+
+ yield next_token
+
+ if eos_token_id is not None and next_token == eos_token_id:
+ return
+
+ def _decode_step_zero_alloc(
+ self,
+ token_id: int,
+ position: int,
+ context_len: int,
+ buffers: DecodeBuffers,
+ ) -> GPUArray:
+ """Single decode step with zero memory allocations.
+
+ Uses pre-allocated DecodeBuffers for all intermediate computations.
+ All operations write to pre-allocated buffers, no new GPU memory is allocated.
+
+ Args:
+ token_id: Current token ID
+ position: Position in sequence
+ context_len: Total context length
+ buffers: Pre-allocated decode buffers
+
+ Returns:
+ Hidden states [1, hidden_size]
+ """
+ # Get token embedding directly to hidden (no copy needed)
+ embedding_lookup(self.embed_tokens, buffers.hidden, token_id)
+
+ # Transformer blocks with fixed cache
+ for block in self.blocks:
+ # Pre-norm: hidden -> norm_out
+ rmsnorm(
+ buffers.hidden, block.attn_norm.weight, block.attn_norm.eps, out=buffers.norm_out
+ )
+
+ # Save residual
+ copy_to(buffers.hidden, buffers.residual)
+
+ # Attention with fixed cache (writes to buffers.hidden)
+ self._attention_forward_zero_alloc(
+ block.attn, buffers.norm_out, position, context_len, buffers
+ )
+
+ # Add residual: hidden = residual + hidden
+ add_inplace(buffers.hidden, buffers.residual)
+
+ # MLP pre-norm
+ copy_to(buffers.hidden, buffers.residual)
+ rmsnorm(buffers.hidden, block.mlp_norm.weight, block.mlp_norm.eps, out=buffers.norm_out)
+
+ # MLP forward (SwiGLU)
+ self._mlp_forward_zero_alloc(block.mlp, buffers.norm_out, buffers)
+
+ # Add residual
+ add_inplace(buffers.hidden, buffers.residual)
+
+ # Final norm
+ rmsnorm(buffers.hidden, self.final_norm.weight, self.final_norm.eps, out=buffers.norm_out)
+ copy_to(buffers.norm_out, buffers.hidden)
+
+ return buffers.hidden
+
+ def _attention_forward_zero_alloc(
+ self,
+ attn: Attention,
+ x: GPUArray,
+ position: int,
+ context_len: int,
+ buffers: DecodeBuffers,
+ use_position_ptr: bool = False,
+ use_context_len_ptr: bool = False,
+ max_kv_len: int | None = None,
+ ) -> None:
+ """Attention forward pass with zero allocations.
+
+ Result is written to buffers.hidden.
+
+ Args:
+ use_position_ptr: If True, read position from buffers.position_buf
+ (for CUDA Graph replay without recapture).
+ use_context_len_ptr: If True, read context_len from buffers.context_len_buf
+ (for CUDA Graph replay without recapture).
+ max_kv_len: Maximum KV length for CUDA Graph shared memory allocation.
+ Required if use_context_len_ptr=True.
+ """
+ # Fused QKV projection (1 matmul replaces 3, then zero-copy narrow views)
+ # This is 4x faster for M=1 with cuBLASLt due to reduced kernel launch overhead
+ attn.qkv_proj(x, out=buffers.qkv_proj_out)
+
+ # Apply biases (fused projection has no bias)
+ if attn.q_proj.bias is not None:
+ bias_add_inplace(buffers.q_view, attn.q_proj.bias)
+ if attn.k_proj.bias is not None:
+ bias_add_inplace(buffers.k_view, attn.k_proj.bias)
+ if attn.v_proj.bias is not None:
+ bias_add_inplace(buffers.v_view, attn.v_proj.bias)
+
+ # Reshape narrow views to 3D using pre-allocated buffers
+ # q_view, k_view, v_view are pre-created zero-copy views of qkv_proj_out
+ reshape_copy(buffers.q_view, (1, attn.num_heads, attn.head_dim), out=buffers.q)
+ reshape_copy(buffers.k_view, (1, attn.num_kv_heads, attn.head_dim), out=buffers.k)
+ reshape_copy(buffers.v_view, (1, attn.num_kv_heads, attn.head_dim), out=buffers.v)
+ q, k, v = buffers.q, buffers.k, buffers.v
+
+ # QK Norm (Qwen3) - zero allocation using pre-allocated buffers
+ if attn.q_norm is not None and buffers.q_2d is not None and buffers.q_flat is not None:
+ # Reshape q [1,H,D] -> q_flat [H,D], apply norm, reshape back to q [1,H,D]
+ reshape_copy(q, (attn.num_heads, attn.head_dim), out=buffers.q_flat)
+ rmsnorm(buffers.q_flat, attn.q_norm.weight, attn.q_norm.eps, out=buffers.q_2d)
+ reshape_copy(buffers.q_2d, (1, attn.num_heads, attn.head_dim), out=buffers.q)
+ q = buffers.q
+ if attn.k_norm is not None and buffers.k_2d is not None and buffers.k_flat is not None:
+ # Reshape k [1,H,D] -> k_flat [H,D], apply norm, reshape back to k [1,H,D]
+ reshape_copy(k, (attn.num_kv_heads, attn.head_dim), out=buffers.k_flat)
+ rmsnorm(buffers.k_flat, attn.k_norm.weight, attn.k_norm.eps, out=buffers.k_2d)
+ reshape_copy(buffers.k_2d, (1, attn.num_kv_heads, attn.head_dim), out=buffers.k)
+ k = buffers.k
+
+ # Apply RoPE using pre-computed GPU tables (zero allocation)
+ if self.config.use_rope and hasattr(self, "_rope_cos_gpu"):
+ # Extract single row from pre-computed tables using GPU kernel
+ if use_position_ptr and buffers.position_buf is not None:
+ # Use _ptr variants for CUDA Graph replay
+ embedding_lookup_ptr(self._rope_cos_gpu, buffers.cos, buffers.position_buf)
+ embedding_lookup_ptr(self._rope_sin_gpu, buffers.sin, buffers.position_buf)
+ else:
+ embedding_lookup(self._rope_cos_gpu, buffers.cos, position)
+ embedding_lookup(self._rope_sin_gpu, buffers.sin, position)
+ # buffers.cos/sin are already [1, head_dim] - use directly
+ rope_inplace(q, k, buffers.cos, buffers.sin)
+
+ # Update KV cache at position (GQA-expanded, transposed)
+ if use_position_ptr and buffers.position_buf is not None:
+ # Use _ptr variants for CUDA Graph replay
+ kv_cache_update_gqa_ptr(k, attn._k_cache, attn.num_heads, buffers.position_buf)
+ kv_cache_update_gqa_ptr(v, attn._v_cache, attn.num_heads, buffers.position_buf)
+ else:
+ kv_cache_update_gqa(k, attn._k_cache, attn.num_heads, position)
+ kv_cache_update_gqa(v, attn._v_cache, attn.num_heads, position)
+
+ # Transpose Q for SDPA: [1, num_heads, head_dim] -> [num_heads, 1, head_dim]
+ transpose_3d_021(q, out=buffers.q_t)
+
+ # SDPA with fixed cache
+ if use_context_len_ptr and buffers.context_len_buf is not None:
+ # Use pointer-based SDPA for CUDA Graph replay
+ assert max_kv_len is not None, "max_kv_len required for CUDA Graph mode"
+ sdpa_causal_fixed_cache_ptr(
+ buffers.q_t,
+ attn._k_cache,
+ attn._v_cache,
+ buffers.attn_out,
+ buffers.context_len_buf,
+ max_kv_len,
+ )
+ else:
+ sdpa_causal_fixed_cache(
+ buffers.q_t, attn._k_cache, attn._v_cache, buffers.attn_out, context_len
+ )
+
+ # Transpose output: [num_heads, 1, head_dim] -> [1, num_heads, head_dim]
+ transpose_3d_021(buffers.attn_out, out=buffers.q) # Reuse q buffer for transposed output
+
+ # Reshape to 2D: [1, hidden_size] - reuse q_proj_out buffer
+ reshape_copy(buffers.q, (1, attn.num_heads * attn.head_dim), out=buffers.q_proj_out)
+
+ # Output projection directly to hidden (eliminates copy)
+ attn.o_proj(buffers.q_proj_out, out=buffers.hidden)
+
+ def _mlp_forward_zero_alloc(
+ self,
+ mlp: MLP,
+ x: GPUArray,
+ buffers: DecodeBuffers,
+ ) -> None:
+ """MLP forward pass with zero allocations (SwiGLU).
+
+ Result is written to buffers.hidden.
+ """
+ if mlp.activation == "silu":
+ # Non-fused SwiGLU (2 separate matmuls) - for debugging
+ mlp.gate_proj(x, out=buffers.mlp_gate)
+ silu(buffers.mlp_gate, out=buffers.mlp_gate)
+
+ mlp.up_proj(x, out=buffers.mlp_up)
+
+ mul_inplace(buffers.mlp_gate, buffers.mlp_up)
+
+ mlp.down_proj(buffers.mlp_gate, out=buffers.hidden)
+ else:
+ # GELU path (GPT-2) - still has allocations, rarely used
+ fc1_out = mlp.fc1(x)
+ gelu_out = gelu(fc1_out)
+ fc2_out = mlp.fc2(gelu_out)
+ copy_to(fc2_out, buffers.hidden)
+
+ def _mlp_forward_batch_zero_alloc(
+ self,
+ mlp: MLP,
+ x: GPUArray,
+ buffers: DecodeBuffers,
+ out: GPUArray,
+ ) -> None:
+ """Batch MLP forward pass with zero allocations (SwiGLU).
+
+ Uses fused gate_up projection for efficiency.
+
+ Args:
+ mlp: MLP module
+ x: Input tensor [seq_len, hidden_size]
+ buffers: Pre-allocated decode buffers
+ out: Output buffer [seq_len, hidden_size] to write result
+ """
+ seq_len = x.shape[0]
+
+ if mlp.activation == "silu":
+ # Fused gate_up projection
+ gate_up_out = buffers.gate_up_out_batch.slice_rows(seq_len)
+ mlp.gate_up_proj(x, out=gate_up_out)
+
+ # Split into gate and up using narrow
+ intermediate_size = mlp.intermediate_size
+ gate = gate_up_out.narrow(0, intermediate_size) # [seq_len, intermediate_size]
+ up = gate_up_out.narrow(intermediate_size, intermediate_size)
+
+ # SiLU in-place on gate
+ silu(gate, out=gate)
+
+ # Multiply gate * up in-place
+ mul_inplace(gate, up)
+
+ # Down projection to output buffer
+ mlp.down_proj(gate, out=out)
+ else:
+ # GELU path - still has allocations (rarely used)
+ fc1_out = mlp.fc1(x)
+ gelu_out = gelu(fc1_out)
+ mlp.fc2(gelu_out, out=out)
+
+ def _prefill_with_buffers(
+ self,
+ input_ids: list[int],
+ buffers: PrefillBuffers,
+ use_cache: bool = True,
+ ) -> tuple[GPUArray, list[tuple | None] | None]:
+ """Prefill forward pass with reduced allocations using pre-allocated buffers.
+
+ Uses PrefillBuffers for projection outputs, attention intermediates, and MLP
+ to reduce memory allocations during prefill. Full zero-allocation requires
+ kernel-level support for partial buffer operations.
+
+ Args:
+ input_ids: Token IDs [seq_len]
+ buffers: Pre-allocated prefill buffers
+ use_cache: Whether to return KV cache
+
+ Returns:
+ Tuple of (hidden_states, present_key_values)
+ """
+ seq_len = len(input_ids)
+ assert seq_len <= buffers.max_seq_len, (
+ f"seq_len {seq_len} > max_seq_len {buffers.max_seq_len}"
+ )
+
+ position_ids = list(range(seq_len))
+
+ # Token embeddings - copy to pre-allocated buffer
+ if not hasattr(self, "_embed_np_cache"):
+ self._embed_np_cache = self.embed_tokens.to_numpy()
+ hidden_np = self._embed_np_cache[input_ids]
+
+ # Add position embeddings (GPT-2 style)
+ if self.position_embed is not None:
+ if not hasattr(self, "_pos_embed_np_cache"):
+ self._pos_embed_np_cache = self.position_embed.to_numpy()
+ hidden_np = hidden_np + self._pos_embed_np_cache[position_ids]
+
+ # Copy to pre-allocated hidden buffer
+ hidden = from_numpy(hidden_np.astype(self._embed_np_cache.dtype))
+ copy_to(hidden, buffers.hidden)
+
+ # Transformer blocks with buffer reuse
+ present_key_values = []
+ for block in self.blocks:
+ # Process using buffers where possible
+ hidden, present_kv = self._prefill_block_with_buffers(
+ block, buffers.hidden, position_ids, buffers, use_cache
+ )
+ present_key_values.append(present_kv)
+
+ # Final norm - reuse norm_out buffer
+ rmsnorm(buffers.hidden, self.final_norm.weight, self.final_norm.eps, out=buffers.norm_out)
+ copy_to(buffers.norm_out, buffers.hidden)
+
+ if use_cache:
+ return buffers.hidden, present_key_values
+ return buffers.hidden, None
+
+ def _prefill_block_with_buffers(
+ self,
+ block: TransformerBlock,
+ hidden: GPUArray,
+ position_ids: list[int],
+ buffers: PrefillBuffers,
+ use_cache: bool,
+ ) -> tuple[GPUArray, tuple | None]:
+ """Single transformer block forward with buffer reuse.
+
+ Args:
+ block: TransformerBlock to process
+ hidden: Input hidden states [seq_len, hidden_size]
+ position_ids: Position IDs for RoPE
+ buffers: Pre-allocated prefill buffers
+ use_cache: Whether to return KV cache
+
+ Returns:
+ Tuple of (output_hidden, present_kv)
+ """
+ # Attention block
+ # Pre-norm -> norm_out
+ rmsnorm(hidden, block.attn_norm.weight, block.attn_norm.eps, out=buffers.norm_out)
+
+ # Save residual
+ copy_to(hidden, buffers.residual)
+
+ # Attention forward with buffers
+ attn_out, present_kv = self._prefill_attention_with_buffers(
+ block.attn, buffers.norm_out, position_ids, buffers, use_cache
+ )
+
+ # Residual connection: hidden = residual + attn_out
+ add_inplace(attn_out, buffers.residual)
+ copy_to(attn_out, buffers.hidden)
+
+ # MLP block
+ # Pre-norm
+ copy_to(buffers.hidden, buffers.residual)
+ rmsnorm(buffers.hidden, block.mlp_norm.weight, block.mlp_norm.eps, out=buffers.norm_out)
+
+ # MLP forward with buffers
+ self._prefill_mlp_with_buffers(block.mlp, buffers.norm_out, buffers)
+
+ # Residual connection
+ add_inplace(buffers.hidden, buffers.residual)
+
+ return buffers.hidden, present_kv
+
+ def _prefill_attention_with_buffers(
+ self,
+ attn: Attention,
+ x: GPUArray,
+ position_ids: list[int],
+ buffers: PrefillBuffers,
+ use_cache: bool,
+ ) -> tuple[GPUArray, tuple | None]:
+ """Attention forward pass with buffer reuse during prefill.
+
+ Args:
+ attn: Attention layer
+ x: Input [seq_len, hidden_size]
+ position_ids: Position IDs for RoPE
+ buffers: Pre-allocated prefill buffers
+ use_cache: Whether to return KV cache
+
+ Returns:
+ Tuple of (output, present_kv)
+ """
+ seq_len = x.shape[0]
+
+ # Project Q, K, V using pre-allocated buffers
+ attn.q_proj(x, out=buffers.q_proj_out)
+ attn.k_proj(x, out=buffers.k_proj_out)
+ attn.v_proj(x, out=buffers.v_proj_out)
+
+ # Reshape to 3D
+ reshape_copy(buffers.q_proj_out, out=buffers.q)
+ reshape_copy(buffers.k_proj_out, out=buffers.k)
+ reshape_copy(buffers.v_proj_out, out=buffers.v)
+ q, k, v = buffers.q, buffers.k, buffers.v
+
+ # QK Norm (Qwen3 style)
+ if attn.q_norm is not None and buffers.q_2d is not None:
+ q_2d = reshape_copy(q, (seq_len * attn.num_heads, attn.head_dim))
+ q_2d = attn.q_norm(q_2d)
+ q = reshape_copy(q_2d, (seq_len, attn.num_heads, attn.head_dim))
+ if attn.k_norm is not None and buffers.k_2d is not None:
+ k_2d = reshape_copy(k, (seq_len * attn.num_kv_heads, attn.head_dim))
+ k_2d = attn.k_norm(k_2d)
+ k = reshape_copy(k_2d, (seq_len, attn.num_kv_heads, attn.head_dim))
+
+ # Apply RoPE
+ if self.config.use_rope and attn._cos is not None and attn._sin is not None:
+ # Use Attention's precomputed cos/sin tables
+ q_dtype = q.dtype
+ if q_dtype == "float16":
+ cos = from_numpy(attn._cos[position_ids].astype(np.float16))
+ sin = from_numpy(attn._sin[position_ids].astype(np.float16))
+ elif q_dtype == "bfloat16":
+ # Fall back to float32 computation for bfloat16
+ cos = from_numpy(attn._cos[position_ids].astype(np.float32))
+ sin = from_numpy(attn._sin[position_ids].astype(np.float32))
+ else:
+ # FP32 path
+ cos = from_numpy(attn._cos[position_ids].astype(np.float32))
+ sin = from_numpy(attn._sin[position_ids].astype(np.float32))
+ # Apply RoPE in-place (FP32 and FP16 have native kernel support)
+ if q_dtype in ("float32", "float16"):
+ rope_inplace(q, k, cos, sin)
+
+ # Store for KV cache - MUST copy since buffers.k/v are reused across layers
+ if use_cache:
+ # Create copies of K, V to avoid aliasing
+ # (shared buffers get overwritten by later layers)
+ k_copy = reshape_copy(k, k.shape)
+ v_copy = reshape_copy(v, v.shape)
+ present_kv = (k_copy, v_copy)
+ else:
+ present_kv = None
+
+ # Expand for GQA
+ if attn.num_kv_groups > 1:
+ k_expanded = repeat_interleave_axis1(k, attn.num_kv_groups)
+ v_expanded = repeat_interleave_axis1(v, attn.num_kv_groups)
+ else:
+ k_expanded = k
+ v_expanded = v
+
+ # Transpose for SDPA: [seq, heads, dim] -> [heads, seq, dim]
+ transpose_3d_021(q, out=buffers.q_t)
+ k_t = transpose_3d_021(k_expanded) # Can't use buffer due to GQA expansion
+ v_t = transpose_3d_021(v_expanded)
+
+ # SDPA with causal mask
+ sdpa_causal(buffers.q_t, k_t, v_t, out=buffers.attn_out)
+
+ # Transpose back and reshape
+ transpose_3d_021(buffers.attn_out, out=buffers.attn_out_t)
+ reshape_copy(buffers.attn_out_t, out=buffers.attn_out_2d)
+
+ # Output projection
+ attn.o_proj(buffers.attn_out_2d, out=buffers.o_proj_out)
+
+ return buffers.o_proj_out, present_kv
+
+ def _prefill_mlp_with_buffers(
+ self,
+ mlp: MLP,
+ x: GPUArray,
+ buffers: PrefillBuffers,
+ ) -> None:
+ """MLP forward pass with buffer reuse during prefill.
+
+ Result is written to buffers.hidden.
+
+ Args:
+ mlp: MLP layer
+ x: Input [seq_len, hidden_size]
+ buffers: Pre-allocated prefill buffers
+ """
+ if mlp.activation == "silu":
+ # SwiGLU: gate_proj -> SiLU -> * up_proj -> down_proj
+ mlp.gate_proj(x, out=buffers.mlp_gate)
+ silu(buffers.mlp_gate, out=buffers.mlp_gate)
+
+ mlp.up_proj(x, out=buffers.mlp_up)
+
+ # Element-wise multiply in-place
+ mul_inplace(buffers.mlp_gate, buffers.mlp_up)
+
+ # Down projection
+ mlp.down_proj(buffers.mlp_gate, out=buffers.mlp_down)
+ copy_to(buffers.mlp_down, buffers.hidden)
+ else:
+ # GELU path (GPT-2)
+ fc1_out = mlp.fc1(x)
+ gelu_out = gelu(fc1_out)
+ fc2_out = mlp.fc2(gelu_out)
+ copy_to(fc2_out, buffers.hidden)
+
+ def _decode_step_fixed_cache(
+ self,
+ token_id: int,
+ position: int,
+ context_len: int,
+ ) -> GPUArray:
+ """Single decode step using fixed-length KV cache (legacy, with allocations).
+
+ Args:
+ token_id: Current token ID
+ position: Position in sequence
+ context_len: Total context length
+
+ Returns:
+ Hidden states [1, hidden_size]
+ """
+ # Get token embedding
+ if not hasattr(self, "_embed_np_cache"):
+ self._embed_np_cache = self.embed_tokens.to_numpy()
+ hidden_np = self._embed_np_cache[token_id : token_id + 1]
+ hidden = from_numpy(hidden_np.astype(self._embed_np_cache.dtype))
+
+ # Transformer blocks with fixed cache
+ for block in self.blocks:
+ # Pre-norm
+ residual = hidden
+ hidden = block.attn_norm(hidden)
+
+ # Attention with fixed cache
+ hidden = block.attn.forward_fixed_cache(hidden, position, context_len)
+ hidden = add(residual, hidden)
+
+ # MLP
+ residual = hidden
+ hidden = block.mlp_norm(hidden)
+ hidden = block.mlp(hidden)
+ hidden = add(residual, hidden)
+
+ # Final norm
+ hidden = self.final_norm(hidden)
+
+ return hidden
+
+ def _decode_step_fixed_cache_batch(
+ self,
+ token_ids: list[int],
+ start_position: int,
+ context_len: int,
+ ) -> GPUArray:
+ """Batch decode step using fixed-length KV cache.
+
+ Processes multiple tokens at once for speculative decoding verification.
+
+ Args:
+ token_ids: List of token IDs to decode [seq_len tokens]
+ start_position: Starting position in sequence (first token's position)
+ context_len: Total context length after adding this batch
+ (should equal start_position + len(token_ids))
+
+ Returns:
+ Hidden states [seq_len, hidden_size]
+ """
+ # Dispatch to optimized single-token path for M=1
+ if len(token_ids) == 1:
+ return self._decode_step_fixed_cache(token_ids[0], start_position, context_len)
+
+ # M > 1: Batch decode path
+ # Get token embeddings for batch
+ if not hasattr(self, "_embed_np_cache"):
+ self._embed_np_cache = self.embed_tokens.to_numpy()
+ hidden_np = self._embed_np_cache[token_ids] # [seq_len, hidden_size]
+ hidden = from_numpy(hidden_np.astype(self._embed_np_cache.dtype))
+
+ # Transformer blocks with fixed cache (batch)
+ for block in self.blocks:
+ # Pre-norm
+ residual = hidden
+ hidden = block.attn_norm(hidden)
+
+ # Attention with fixed cache (batch)
+ hidden = block.attn.forward_fixed_cache_batch(hidden, start_position, context_len)
+ hidden = add(residual, hidden)
+
+ # MLP
+ residual = hidden
+ hidden = block.mlp_norm(hidden)
+ hidden = block.mlp(hidden)
+ hidden = add(residual, hidden)
+
+ # Final norm
+ hidden = self.final_norm(hidden)
+
+ return hidden
+
+ def _decode_step_fixed_cache_batch_zero_alloc(
+ self,
+ token_ids: list[int],
+ start_position: int,
+ context_len: int,
+ buffers: DecodeBuffers,
+ ) -> GPUArray:
+ """Batch decode step using pre-allocated buffers (zero-allocation).
+
+ This function is designed to be CUDA Graph capture compatible.
+ All intermediate buffers are pre-allocated in DecodeBuffers.
+
+ Args:
+ token_ids: List of token IDs to decode [seq_len tokens]
+ start_position: Starting position in sequence (first token's position)
+ context_len: Total context length after adding this batch
+ buffers: Pre-allocated batch decode buffers
+
+ Returns:
+ Hidden states [seq_len, hidden_size] (view into buffers.hidden_batch)
+
+ Note:
+ Requires buffers.max_batch_size > 0 and len(token_ids) <= max_batch_size.
+ TODO: CUDA Graph capture can be added once this path is validated.
+ """
+ seq_len = len(token_ids)
+
+ if buffers.max_batch_size == 0:
+ raise RuntimeError(
+ "Batch buffers not allocated. Call DecodeBuffers.allocate(..., max_batch_size=8)"
+ )
+ if seq_len > buffers.max_batch_size:
+ raise ValueError(
+ f"seq_len ({seq_len}) exceeds max_batch_size ({buffers.max_batch_size})"
+ )
+
+ # Get embeddings (still uses numpy - small one-time cost)
+ if not hasattr(self, "_embed_np_cache"):
+ self._embed_np_cache = self.embed_tokens.to_numpy()
+ hidden_np = self._embed_np_cache[token_ids] # [seq_len, hidden_size]
+
+ # Copy to batch hidden buffer
+ assert buffers.hidden_batch is not None
+ buffers.hidden_batch._get_native().copy_from_numpy(
+ hidden_np.astype(self._embed_np_cache.dtype)
+ )
+
+ # Use slice_rows for actual seq_len (logical batch size)
+ # slice_rows creates a zero-copy view of the first N rows
+ hidden = buffers.hidden_batch.slice_rows(seq_len)
+ residual_buf = (
+ buffers.residual_batch.slice_rows(seq_len) if buffers.residual_batch else None
+ )
+ norm_out_buf = (
+ buffers.norm_out_batch.slice_rows(seq_len) if buffers.norm_out_batch else None
+ )
+
+ # Transformer blocks
+ for block in self.blocks:
+ # Pre-norm: attn_norm(hidden) -> norm_out
+ if norm_out_buf is not None:
+ rmsnorm(hidden, block.attn_norm.weight, block.attn_norm.eps, out=norm_out_buf)
+ else:
+ norm_out_buf = block.attn_norm(hidden)
+
+ # Save residual
+ if residual_buf is not None:
+ copy_to(hidden, residual_buf)
+ else:
+ residual_buf = hidden
+
+ # Attention with fixed cache (batch) - uses existing path for now
+ # TODO: Add forward_fixed_cache_batch_zero_alloc to Attention class
+ attn_out = block.attn.forward_fixed_cache_batch(
+ norm_out_buf, start_position, context_len
+ )
+
+ # Residual connection: hidden = residual + attn_out
+ add_inplace(residual_buf, attn_out)
+ hidden = residual_buf
+
+ # MLP norm
+ if norm_out_buf is not None:
+ rmsnorm(hidden, block.mlp_norm.weight, block.mlp_norm.eps, out=norm_out_buf)
+ else:
+ norm_out_buf = block.mlp_norm(hidden)
+
+ # Save residual for MLP
+ if residual_buf is not hidden:
+ copy_to(hidden, residual_buf)
+
+ # MLP - uses existing path for now
+ # TODO: Add zero-alloc MLP path
+ mlp_out = block.mlp(norm_out_buf)
+
+ # Residual connection
+ add_inplace(residual_buf, mlp_out)
+ hidden = residual_buf
+
+ # Final norm
+ if norm_out_buf is not None:
+ rmsnorm(hidden, self.final_norm.weight, self.final_norm.eps, out=norm_out_buf)
+ return norm_out_buf
+ else:
+ return self.final_norm(hidden)
+
+ # =========================================================================
+ # Self-Speculative Decoding
+ # =========================================================================
+
+ def snapshot_kv_cache(self) -> list[tuple[np.ndarray, np.ndarray]]:
+ """Snapshot all layer KV caches to CPU memory.
+
+ Returns:
+ List of (k_cache_np, v_cache_np) tuples, one per layer.
+ Each cache is numpy array of shape [num_heads, max_seq_len, head_dim].
+ """
+ snapshot = []
+ for block in self.blocks:
+ k_np = block.attn._k_cache.to_numpy().copy()
+ v_np = block.attn._v_cache.to_numpy().copy()
+ snapshot.append((k_np, v_np))
+ return snapshot
+
+ def restore_kv_cache(self, snapshot: list[tuple[np.ndarray, np.ndarray]]) -> None:
+ """Restore all layer KV caches from CPU snapshot.
+
+ Args:
+ snapshot: List of (k_cache_np, v_cache_np) tuples from snapshot_kv_cache().
+
+ Note:
+ This method copies data into existing arrays rather than replacing them.
+ This is critical for CUDA Graph compatibility - the graph captures pointer
+ addresses, so we must preserve the existing arrays.
+ """
+ for i, block in enumerate(self.blocks):
+ k_np, v_np = snapshot[i]
+ # Copy data into existing arrays (preserves pointers for CUDA Graph)
+ k_np_typed: np.ndarray = k_np.astype(np.float16)
+ v_np_typed: np.ndarray = v_np.astype(np.float16)
+ block.attn._k_cache._get_native().copy_from_numpy(k_np_typed)
+ block.attn._v_cache._get_native().copy_from_numpy(v_np_typed)
+
+ def _draft_forward_early_layers(
+ self,
+ token_id: int,
+ position: int,
+ context_len: int,
+ num_draft_layers: int,
+ ) -> GPUArray:
+ """Forward pass through only the first N layers (draft model).
+
+ Uses the same KV cache as the full model but only updates early layers.
+ After draft is done, the early layer KV entries need to be restored
+ before running the full model verification.
+
+ Args:
+ token_id: Current token ID
+ position: Position in sequence
+ context_len: Total context length
+ num_draft_layers: Number of early layers to use as draft
+
+ Returns:
+ Hidden states [1, hidden_size] after num_draft_layers
+ """
+ # Get token embedding
+ if not hasattr(self, "_embed_np_cache"):
+ self._embed_np_cache = self.embed_tokens.to_numpy()
+ hidden_np = self._embed_np_cache[token_id : token_id + 1]
+ hidden = from_numpy(hidden_np.astype(self._embed_np_cache.dtype))
+
+ # Only run through first num_draft_layers blocks
+ for i in range(min(num_draft_layers, len(self.blocks))):
+ block = self.blocks[i]
+ # Pre-norm
+ residual = hidden
+ hidden = block.attn_norm(hidden)
+
+ # Attention with fixed cache
+ hidden = block.attn.forward_fixed_cache(hidden, position, context_len)
+ hidden = add(residual, hidden)
+
+ # MLP
+ residual = hidden
+ hidden = block.mlp_norm(hidden)
+ hidden = block.mlp(hidden)
+ hidden = add(residual, hidden)
+
+ # Note: We do NOT apply final_norm here since draft output
+ # is only used for sampling, not for precise logits
+ return hidden
+
+ def _draft_get_logits(self, hidden: GPUArray) -> GPUArray:
+ """Get logits from draft hidden states (after early layers).
+
+ This applies final_norm and then computes logits.
+ Note: The draft hidden states are from early layers, so the logits
+ may not be identical to full model logits.
+ """
+ # Apply final norm (needed for proper logits computation)
+ hidden_normed = self.final_norm(hidden)
+ return self.get_logits(hidden_normed)
+
+ def decode_step_self_speculative_lookahead(
+ self,
+ token_id: int,
+ max_draft_tokens: int = 4,
+ draft_layers: int = 8,
+ ) -> tuple[list[int], dict]:
+ """Self-speculative decode step with GPU-side lookahead KV (no CPU copies).
+
+ Uses lookahead KV cache management to avoid CPU-GPU transfers.
+
+ IMPORTANT: Before calling this method:
+ 1. Run prefill and store KV using kv_cache_prefill_gqa()
+ 2. Call set_lookahead_confirmed_pos(prefill_len) to mark prefill KV as committed
+
+ Algorithm:
+ 1. Generate draft tokens using early layers (writes to speculative positions)
+ 2. Reset lookahead, verify with full model in batch
+ 3. Accept tokens until first disagreement
+ 4. Re-run for accepted tokens to ensure correct KV
+ 5. Commit accepted tokens
+
+ Args:
+ token_id: Current token ID (the last accepted token)
+ max_draft_tokens: Maximum number of draft tokens to generate
+ draft_layers: Number of early layers to use as draft
+
+ Returns:
+ Tuple of:
+ - accepted_tokens: List of accepted token IDs
+ - stats: Dict with 'draft_count', 'accepted_count' for analysis
+ """
+ confirmed_pos = self.get_lookahead_confirmed_pos()
+
+ # === Step 1: Generate draft tokens using early layers ===
+ # Reset lookahead before draft phase
+ self.reset_lookahead_all()
+
+ draft_tokens = []
+ current_token = token_id
+
+ for i in range(max_draft_tokens):
+ pos = confirmed_pos + i
+ ctx = confirmed_pos + i + 1
+ # Forward through early layers only
+ hidden = self._draft_forward_early_layers(current_token, pos, ctx, draft_layers)
+ logits = self._draft_get_logits(hidden)
+ logits_np = logits.to_numpy()[-1]
+ next_token = int(np.argmax(logits_np))
+
+ draft_tokens.append(next_token)
+ current_token = next_token
+
+ # === Step 2: Reset and verify with full model in batch ===
+ self.reset_lookahead_all()
+
+ verify_input = [token_id] + draft_tokens[:-1]
+ verify_ctx = confirmed_pos + len(verify_input)
+
+ hidden_batch = self._decode_step_fixed_cache_batch(verify_input, confirmed_pos, verify_ctx)
+ verify_logits = self.get_logits(hidden_batch)
+ verify_logits_np = verify_logits.to_numpy()
+
+ # === Step 3: Accept/Reject tokens ===
+ accepted_tokens = []
+ for i, draft_token in enumerate(draft_tokens):
+ target_token = int(np.argmax(verify_logits_np[i]))
+
+ if target_token == draft_token:
+ accepted_tokens.append(draft_token)
+ else:
+ accepted_tokens.append(target_token)
+ break
+
+ # === Step 4: Re-run for accepted tokens if partial accept ===
+ if len(accepted_tokens) < max_draft_tokens:
+ self.reset_lookahead_all()
+ # Use CUDA Graph if available
+ use_graph = hasattr(self, "_decode_graph_ready") and self._decode_graph_ready
+ current = token_id
+ for i, acc_token in enumerate(accepted_tokens):
+ pos = confirmed_pos + i
+ ctx = confirmed_pos + i + 1
+ if use_graph:
+ self._decode_step_graph_replay(current, pos, ctx)
+ else:
+ self._decode_step_fixed_cache(current, pos, ctx)
+ current = acc_token
+
+ # === Step 5: Commit accepted tokens ===
+ self.commit_lookahead_all(len(accepted_tokens))
+
+ stats = {
+ "draft_count": len(draft_tokens),
+ "accepted_count": len(
+ [
+ t
+ for i, t in enumerate(accepted_tokens)
+ if i < len(draft_tokens) and t == draft_tokens[i]
+ ]
+ ),
+ }
+
+ return accepted_tokens, stats
+
+ # =========================================================================
+ # Lookahead KV Cache Management (GPU-side, no CPU copies)
+ # =========================================================================
+
+ def set_lookahead_confirmed_pos(self, pos: int) -> None:
+ """Set confirmed position for all layers (e.g., after prefill).
+
+ Args:
+ pos: Position where KV is finalized (tokens 0 to pos-1 are committed).
+ """
+ for block in self.blocks:
+ block.attn.set_confirmed_pos(pos)
+
+ def reset_lookahead_all(self) -> None:
+ """Reset lookahead pointer to confirmed position for all layers.
+
+ Called at the start of each Jacobi iteration. This resets the write
+ pointer without modifying KV cache - speculative positions will be
+ overwritten by the next forward pass.
+ """
+ for block in self.blocks:
+ block.attn.reset_lookahead()
+
+ def commit_lookahead_all(self, n_accepted: int) -> None:
+ """Commit accepted tokens for all layers.
+
+ Args:
+ n_accepted: Number of accepted tokens to commit.
+ """
+ for block in self.blocks:
+ block.attn.commit_lookahead(n_accepted)
+
+ def get_lookahead_confirmed_pos(self) -> int:
+ """Get current confirmed position (from first layer)."""
+ return self.blocks[0].attn.get_confirmed_pos()
+
+ # =========================================================================
+ # Jacobi Decoding
+ # =========================================================================
+
+ def _init_jacobi_guess(
+ self,
+ last_token: int,
+ position: int,
+ context_len: int,
+ n_tokens: int,
+ strategy: Literal["repeat", "ngram", "greedy"],
+ ) -> list[int]:
+ """Initialize guess tokens for Jacobi decoding.
+
+ Args:
+ last_token: The last accepted token
+ position: Current position in sequence
+ context_len: Current context length
+ n_tokens: Number of tokens to guess
+ strategy: Initialization strategy
+ - "repeat": Repeat last_token n times
+ - "ngram": Use n-gram cache (falls back to repeat if no match)
+ - "greedy": Run greedy decode to get initial guess
+
+ Returns:
+ List of n_tokens guessed token IDs
+ """
+ if strategy == "repeat":
+ return [last_token] * n_tokens
+
+ elif strategy == "ngram":
+ # N-gram cache lookup (simple implementation)
+ # Check if we have this token in recent history
+ if hasattr(self, "_ngram_cache") and last_token in self._ngram_cache:
+ cached = self._ngram_cache[last_token]
+ if len(cached) >= n_tokens:
+ return cached[:n_tokens]
+ # Fallback to repeat
+ return [last_token] * n_tokens
+
+ elif strategy == "greedy":
+ # Run greedy sequential decode to get initial guess
+ # This is expensive but gives best initial guess
+ kv_snapshot = self.snapshot_kv_cache()
+ guess = []
+ pos = position
+ ctx = context_len
+ current = last_token
+
+ for _ in range(n_tokens):
+ hidden = self._decode_step_fixed_cache(current, pos, ctx)
+ logits = self.get_logits(hidden)
+ next_token = int(np.argmax(logits.to_numpy()[-1]))
+ guess.append(next_token)
+ current = next_token
+ pos += 1
+ ctx += 1
+
+ # Restore KV cache
+ self.restore_kv_cache(kv_snapshot)
+ return guess
+
+ else:
+ raise ValueError(f"Unknown init strategy: {strategy}")
+
+ # =========================================================================
+ # Jacobi Decoding with Lookahead KV (GPU-side, no CPU copies)
+ # =========================================================================
+
+ def _init_jacobi_guess_lookahead(
+ self,
+ last_token: int,
+ n_tokens: int,
+ strategy: Literal["repeat", "ngram", "greedy"],
+ ) -> list[int]:
+ """Initialize guess tokens for Jacobi lookahead (no CPU copies).
+
+ Args:
+ last_token: The last accepted token
+ n_tokens: Number of tokens to guess
+ strategy: Initialization strategy
+ - "repeat": Repeat last_token n times
+ - "ngram": Use n-gram cache (falls back to repeat)
+ - "greedy": Run greedy decode (writes to lookahead positions)
+
+ Returns:
+ List of n_tokens guessed token IDs
+ """
+ if strategy == "repeat":
+ return [last_token] * n_tokens
+
+ elif strategy == "ngram":
+ if hasattr(self, "_ngram_cache") and last_token in self._ngram_cache:
+ cached = self._ngram_cache[last_token]
+ if len(cached) >= n_tokens:
+ return cached[:n_tokens]
+ return [last_token] * n_tokens
+
+ elif strategy == "greedy":
+ # Run greedy decode using lookahead positions
+ # This writes KV at [confirmed_pos, confirmed_pos + n_tokens)
+ confirmed_pos = self.get_lookahead_confirmed_pos()
+ guess = []
+ current = last_token
+
+ for i in range(n_tokens):
+ pos = confirmed_pos + i
+ ctx = confirmed_pos + i + 1
+ hidden = self._decode_step_fixed_cache(current, pos, ctx)
+ logits = self.get_logits(hidden)
+ next_token = int(np.argmax(logits.to_numpy()[-1]))
+ guess.append(next_token)
+ current = next_token
+
+ # Reset lookahead after greedy init (KV will be overwritten)
+ self.reset_lookahead_all()
+ return guess
+
+ else:
+ raise ValueError(f"Unknown init strategy: {strategy}")
+
+ def decode_step_jacobi_lookahead(
+ self,
+ token_id: int,
+ n_tokens: int = 8,
+ max_iter: int = 3,
+ init_strategy: Literal["repeat", "ngram", "greedy"] = "repeat",
+ ) -> tuple[list[int], dict]:
+ """Jacobi decoding step with GPU-side lookahead KV (no CPU copies).
+
+ This method uses the lookahead KV cache management to avoid all
+ CPU-GPU memory transfers during Jacobi iterations.
+
+ IMPORTANT: Before calling this method:
+ 1. Run prefill and store KV using kv_cache_prefill_gqa()
+ 2. Call set_lookahead_confirmed_pos(prefill_len) to mark prefill KV as committed
+
+ Algorithm:
+ 1. Initialize N future positions with a guess
+ 2. Reset lookahead pointer (no KV modification)
+ 3. Batch forward - writes KV at [confirmed_pos, confirmed_pos + n_tokens)
+ 4. Update guess with argmax(logits)
+ 5. Repeat until convergence or max_iter
+ 6. Commit accepted tokens by advancing confirmed_pos
+
+ Args:
+ token_id: Current token ID (the last accepted token)
+ n_tokens: Number of tokens to decode in parallel (default: 8)
+ max_iter: Maximum iterations for convergence (default: 3)
+ init_strategy: How to initialize guess tokens
+ - "repeat": Repeat last token (fast, simple)
+ - "ngram": Use n-gram cache if available
+ - "greedy": Run greedy decode first (slow but accurate)
+
+ Returns:
+ Tuple of:
+ - accepted_tokens: List of accepted token IDs
+ - stats: Dict with 'iterations', 'converged', 'accepted_count'
+ """
+ # Get confirmed position (this is our starting point)
+ confirmed_pos = self.get_lookahead_confirmed_pos()
+
+ # Initialize guess (may use lookahead positions for greedy)
+ guess = self._init_jacobi_guess_lookahead(token_id, n_tokens, init_strategy)
+
+ iterations_used = 0
+ converged = False
+ prev_guess = None
+
+ for iteration in range(max_iter):
+ iterations_used = iteration + 1
+
+ # Reset lookahead pointer (does NOT modify KV cache)
+ self.reset_lookahead_all()
+
+ # Batch forward: input [last_token, guess[0], ..., guess[n-2]]
+ # produces logits for [guess[0], guess[1], ..., guess[n-1]]
+ # Writes KV at [confirmed_pos, confirmed_pos + n_tokens)
+ input_tokens = [token_id] + guess[:-1]
+ start_pos = confirmed_pos
+ ctx_len = confirmed_pos + len(input_tokens)
+
+ hidden = self._decode_step_fixed_cache_batch(input_tokens, start_pos, ctx_len)
+ logits = self.get_logits(hidden)
+ logits_np = logits.to_numpy() # [n_tokens, vocab_size]
+
+ # Update guess with argmax
+ new_guess = [int(np.argmax(logits_np[i])) for i in range(n_tokens)]
+
+ # Check full convergence
+ if new_guess == guess:
+ converged = True
+ break
+
+ prev_guess = guess
+ guess = new_guess
+
+ # Find longest converged prefix
+ if converged:
+ accepted_tokens = guess
+ else:
+ accepted_tokens = []
+ if prev_guess is not None:
+ for i in range(n_tokens):
+ if guess[i] == prev_guess[i]:
+ accepted_tokens.append(guess[i])
+ else:
+ break
+ if len(accepted_tokens) == 0:
+ accepted_tokens = [guess[0]]
+
+ # Commit accepted tokens - this is the ONLY state change
+ # The KV for accepted tokens is already written from the last iteration
+ # We just need to run one more forward to ensure KV is correct
+ self.reset_lookahead_all()
+
+ # Re-run with just the accepted tokens to ensure KV is correct
+ if len(accepted_tokens) < n_tokens:
+ # KV may have extra speculative entries - need to overwrite with correct values
+ # Run sequential for accepted tokens only
+ # Use CUDA Graph if available
+ use_graph = hasattr(self, "_decode_graph_ready") and self._decode_graph_ready
+ current = token_id
+ for i, acc_token in enumerate(accepted_tokens):
+ pos = confirmed_pos + i
+ ctx = confirmed_pos + i + 1
+ if use_graph:
+ self._decode_step_graph_replay(current, pos, ctx)
+ else:
+ self._decode_step_fixed_cache(current, pos, ctx)
+ current = acc_token
+ # If all converged, KV is already correct from last batch forward
+
+ # Commit the accepted tokens
+ self.commit_lookahead_all(len(accepted_tokens))
+
+ # Update n-gram cache for future use
+ if not hasattr(self, "_ngram_cache"):
+ self._ngram_cache: dict[int, list[int]] = {}
+ self._ngram_cache[token_id] = accepted_tokens.copy()
+
+ stats = {
+ "iterations": iterations_used,
+ "converged": converged,
+ "accepted_count": len(accepted_tokens),
+ }
+
+ return accepted_tokens, stats
+
+
+# =============================================================================
+# Type Aliases
+# =============================================================================
+
+# GPT2Model and LlamaModel are now simple aliases for CausalTransformerModel.
+# All models use CausalTransformerModel as the single runtime type.
+GPT2Model = CausalTransformerModel
+LlamaModel = CausalTransformerModel
+
+# Legacy component aliases (import from layers module)
+RMSNorm = Norm # Use Norm with norm_type="rmsnorm"
+LayerNorm = Norm # Use Norm with norm_type="layernorm"
+LlamaAttention = Attention
+LlamaMLP = MLP
+LlamaBlock = TransformerBlock
+CausalSelfAttention = Attention
diff --git a/src/pygpukit/llm/quant.py b/src/pygpukit/llm/quant.py
new file mode 100644
index 0000000..81828ed
--- /dev/null
+++ b/src/pygpukit/llm/quant.py
@@ -0,0 +1,427 @@
+"""Quantization configuration and utilities for PyGPUkit LLM.
+
+Provides:
+- FP8QuantConfig: FP8 quantization configuration
+- QATQuantConfig: QAT (Quantization-Aware Training) configuration
+- PruningConfig: Pruning configuration
+- SparsityConfig: Sparsity pattern configuration
+- ModelOptimizationInfo: Combined optimization information
+- FP8 dequantization utilities
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+if TYPE_CHECKING:
+ from pygpukit.core.array import GPUArray
+ from pygpukit.llm.safetensors import SafeTensorsFile, ShardedSafeTensorsFile
+
+
+# =============================================================================
+# FP8 Quantization Support
+# =============================================================================
+
+
+@dataclass
+class FP8QuantConfig:
+ """FP8 quantization configuration from HuggingFace config.json."""
+
+ quant_method: str # "fp8"
+ fmt: str # "e4m3" or "e5m2"
+ weight_block_size: tuple[int, int] # e.g., (128, 128)
+ modules_to_not_convert: list[str] # List of module name patterns to skip
+
+ @classmethod
+ def from_config(cls, config: dict) -> FP8QuantConfig | None:
+ """Parse quantization config from HF config.json."""
+ qc = config.get("quantization_config")
+ if qc is None or qc.get("quant_method") != "fp8":
+ return None
+
+ block_size = qc.get("weight_block_size", [128, 128])
+ return cls(
+ quant_method="fp8",
+ fmt=qc.get("fmt", "e4m3"),
+ weight_block_size=(block_size[0], block_size[1]),
+ modules_to_not_convert=qc.get("modules_to_not_convert", []),
+ )
+
+
+# =============================================================================
+# QAT/QAD Quantization Support (Issue #115)
+# =============================================================================
+
+
+@dataclass
+class QATQuantConfig:
+ """QAT (Quantization-Aware Training) configuration.
+
+ Supports models trained with:
+ - NVIDIA TensorRT Model Optimizer
+ - HuggingFace Optimum
+ - PyTorch Quantization
+
+ Reference:
+ - https://nvidia.github.io/TensorRT-Model-Optimizer/
+ - https://developer.nvidia.com/blog/top-5-ai-model-optimization-techniques-for-faster-smarter-inference/
+ """
+
+ quant_method: str # "qat", "modelopt", "nvfp4", etc.
+ quant_algo: str # "FP8", "INT8", "NVFP4", "W8A8", etc.
+ group_size: int # Block/group size for quantization
+ kv_cache_quant_algo: str | None # KV cache quantization (optional)
+ exclude_modules: list[str] # Modules to skip quantization
+ producer: str | None # Tool that produced the checkpoint (e.g., "modelopt")
+ producer_version: str | None # Version of the producer tool
+
+ @classmethod
+ def from_config(cls, config: dict) -> QATQuantConfig | None:
+ """Parse QAT config from HF config.json or hf_quant_config.json."""
+ # Check for TensorRT Model Optimizer format (hf_quant_config.json style)
+ if "producer" in config and "quantization" in config:
+ producer_info = config.get("producer", {})
+ quant_info = config.get("quantization", {})
+ return cls(
+ quant_method="modelopt",
+ quant_algo=quant_info.get("quant_algo", "unknown"),
+ group_size=quant_info.get("group_size", 128),
+ kv_cache_quant_algo=quant_info.get("kv_cache_quant_algo"),
+ exclude_modules=quant_info.get("exclude_modules", []),
+ producer=producer_info.get("name"),
+ producer_version=producer_info.get("version"),
+ )
+
+ # Check for HF quantization_config with QAT method
+ qc = config.get("quantization_config")
+ if qc is None:
+ return None
+
+ quant_method = qc.get("quant_method", "")
+ # QAT methods: "qat", "awq", "gptq", etc. (exclude "fp8" which is handled separately)
+ qat_methods = {"qat", "awq", "gptq", "bnb", "modelopt"}
+ if quant_method not in qat_methods:
+ return None
+
+ return cls(
+ quant_method=quant_method,
+ quant_algo=qc.get("quant_algo", qc.get("bits", "unknown")),
+ group_size=qc.get("group_size", qc.get("block_size", 128)),
+ kv_cache_quant_algo=qc.get("kv_cache_quant_algo"),
+ exclude_modules=qc.get("modules_to_not_convert", []),
+ producer=None,
+ producer_version=None,
+ )
+
+
+# =============================================================================
+# Pruning Support (Issue #115)
+# =============================================================================
+
+
+@dataclass
+class PruningConfig:
+ """Pruning configuration for structurally smaller models.
+
+ Supports models pruned with:
+ - NVIDIA TensorRT Model Optimizer
+ - HuggingFace nn_pruning
+ - Neural Compressor
+
+ Reference:
+ - https://github.com/huggingface/nn_pruning
+ - https://github.com/NVIDIA/TensorRT-Model-Optimizer
+ """
+
+ pruning_method: str # "magnitude", "movement", "structured", "unstructured"
+ sparsity: float # Target sparsity (0.0 to 1.0)
+ pruned_heads: dict[int, list[int]] | None # Layer -> pruned head indices
+ is_structured: bool # True if structured pruning (removes entire heads/neurons)
+
+ @classmethod
+ def from_config(cls, config: dict) -> PruningConfig | None:
+ """Parse pruning config from HF config.json."""
+ # Check for pruned_heads (HuggingFace standard)
+ pruned_heads = config.get("pruned_heads")
+ if pruned_heads:
+ # Convert string keys to int if needed
+ if isinstance(pruned_heads, dict):
+ pruned_heads = {int(k): v for k, v in pruned_heads.items()}
+ return cls(
+ pruning_method="structured",
+ sparsity=0.0, # Unknown from config alone
+ pruned_heads=pruned_heads,
+ is_structured=True,
+ )
+
+ # Check for pruning_config section
+ pc = config.get("pruning_config")
+ if pc is None:
+ return None
+
+ return cls(
+ pruning_method=pc.get("pruning_type", pc.get("method", "unknown")),
+ sparsity=pc.get("target_sparsity", pc.get("sparsity", 0.0)),
+ pruned_heads=pc.get("pruned_heads"),
+ is_structured=pc.get("is_structured", pc.get("structured", False)),
+ )
+
+
+# =============================================================================
+# Sparsity Pattern Support (Issue #115)
+# =============================================================================
+
+
+@dataclass
+class SparsityConfig:
+ """Sparsity pattern configuration for sparse tensor operations.
+
+ Supports:
+ - 2:4 structured sparsity (Ampere+)
+ - Block sparsity patterns
+ - Custom sparsity masks
+
+ Reference:
+ - https://developer.nvidia.com/blog/accelerating-inference-with-sparsity-using-ampere-and-tensorrt/
+ """
+
+ pattern: str # "2:4", "4:8", "block", "unstructured"
+ block_size: tuple[int, int] | None # For block sparsity
+ density: float # Non-zero ratio (1 - sparsity)
+
+ @classmethod
+ def from_config(cls, config: dict) -> SparsityConfig | None:
+ """Parse sparsity config from HF config.json."""
+ sc = config.get("sparsity_config")
+ if sc is None:
+ # Check for sparsity in quantization_config
+ qc = config.get("quantization_config", {})
+ sparsity_pattern = qc.get("sparsity_pattern")
+ if sparsity_pattern:
+ return cls(
+ pattern=sparsity_pattern,
+ block_size=None,
+ density=1.0 - qc.get("sparsity", 0.5),
+ )
+ return None
+
+ pattern = sc.get("pattern", sc.get("sparsity_pattern", "unknown"))
+ block_size = sc.get("block_size")
+ if block_size and isinstance(block_size, list):
+ block_size = tuple(block_size)
+
+ return cls(
+ pattern=pattern,
+ block_size=block_size,
+ density=sc.get("density", 1.0 - sc.get("sparsity", 0.0)),
+ )
+
+ def is_2_4_sparse(self) -> bool:
+ """Check if this is 2:4 structured sparsity (Ampere+ TensorCore)."""
+ return self.pattern == "2:4"
+
+
+# =============================================================================
+# Model Optimization Info (Issue #115)
+# =============================================================================
+
+
+@dataclass
+class ModelOptimizationInfo:
+ """Combined optimization information for a model.
+
+ Aggregates all optimization techniques applied to the model:
+ - Quantization (FP8, QAT, etc.)
+ - Pruning (structured, unstructured)
+ - Sparsity (2:4, block)
+ """
+
+ fp8_config: FP8QuantConfig | None
+ qat_config: QATQuantConfig | None
+ pruning_config: PruningConfig | None
+ sparsity_config: SparsityConfig | None
+
+ @classmethod
+ def from_config(cls, config: dict) -> ModelOptimizationInfo:
+ """Parse all optimization configs from config.json."""
+ return cls(
+ fp8_config=FP8QuantConfig.from_config(config),
+ qat_config=QATQuantConfig.from_config(config),
+ pruning_config=PruningConfig.from_config(config),
+ sparsity_config=SparsityConfig.from_config(config),
+ )
+
+ def has_any_optimization(self) -> bool:
+ """Check if any optimization is applied."""
+ return any(
+ [
+ self.fp8_config,
+ self.qat_config,
+ self.pruning_config,
+ self.sparsity_config,
+ ]
+ )
+
+ def summary(self) -> str:
+ """Return a summary string of optimizations."""
+ parts = []
+ if self.fp8_config:
+ parts.append(f"FP8({self.fp8_config.fmt})")
+ if self.qat_config:
+ parts.append(f"QAT({self.qat_config.quant_algo})")
+ if self.pruning_config:
+ parts.append(f"Pruned({self.pruning_config.pruning_method})")
+ if self.sparsity_config:
+ parts.append(f"Sparse({self.sparsity_config.pattern})")
+ return ", ".join(parts) if parts else "None"
+
+
+# =============================================================================
+# FP8 E4M3 Conversion Utilities
+# =============================================================================
+
+# FP8 E4M3 to float32 lookup table (256 entries)
+# Format: 1 sign bit, 4 exponent bits, 3 mantissa bits
+# Special values: NaN (0x7F/0xFF), no infinity
+_FP8_E4M3_TO_F32_TABLE: np.ndarray | None = None
+
+
+def _get_fp8_e4m3_table() -> np.ndarray:
+ """Build FP8 E4M3 to float32 conversion lookup table."""
+ global _FP8_E4M3_TO_F32_TABLE
+ if _FP8_E4M3_TO_F32_TABLE is not None:
+ return _FP8_E4M3_TO_F32_TABLE
+
+ table = np.zeros(256, dtype=np.float32)
+ for i in range(256):
+ # Extract components
+ sign = (i >> 7) & 1
+ exp = (i >> 3) & 0xF # 4 exponent bits
+ mant = i & 0x7 # 3 mantissa bits
+
+ if exp == 0xF and mant == 0x7:
+ # NaN (0x7F and 0xFF)
+ table[i] = np.nan
+ elif exp == 0:
+ # Subnormal (exponent = 0)
+ # Value = (-1)^sign * 2^(-6) * (0.mantissa)
+ value = (mant / 8.0) * (2.0**-6)
+ table[i] = -value if sign else value
+ else:
+ # Normal
+ # Value = (-1)^sign * 2^(exp-7) * (1.mantissa)
+ value = (1.0 + mant / 8.0) * (2.0 ** (exp - 7))
+ table[i] = -value if sign else value
+
+ _FP8_E4M3_TO_F32_TABLE = table
+ return table
+
+
+def dequantize_fp8_e4m3_block(
+ fp8_bytes: np.ndarray,
+ scale_inv: np.ndarray,
+ block_size: tuple[int, int] = (128, 128),
+) -> np.ndarray:
+ """Dequantize FP8 E4M3 weight with block-wise scaling.
+
+ Args:
+ fp8_bytes: Raw FP8 data as uint8 array, shape [H, W]
+ scale_inv: Inverse scale factors, shape [H//block_h, W//block_w]
+ block_size: Block size for quantization (default 128x128)
+
+ Returns:
+ Dequantized float32 array, shape [H, W]
+ """
+ # Convert FP8 bytes to float32 using lookup table
+ table = _get_fp8_e4m3_table()
+ f32 = table[fp8_bytes.ravel()].reshape(fp8_bytes.shape)
+
+ # Apply block-wise scaling
+ H, W = f32.shape
+ block_h, block_w = block_size
+
+ # Ensure scale_inv is float32 for computation
+ if scale_inv.dtype != np.float32:
+ # BF16 stored as uint16 -> convert to float32
+ if scale_inv.dtype == np.uint16:
+ scale_f32 = np.empty(scale_inv.shape, dtype=np.float32)
+ scale_f32.view(np.uint32)[:] = scale_inv.astype(np.uint32) << 16
+ else:
+ scale_f32 = scale_inv.astype(np.float32)
+ else:
+ scale_f32 = scale_inv
+
+ # Apply scaling per block using broadcasting
+ num_blocks_h = H // block_h
+ num_blocks_w = W // block_w
+
+ # Reshape for vectorized block scaling
+ f32_reshaped = f32.reshape(num_blocks_h, block_h, num_blocks_w, block_w)
+ scale_expanded = scale_f32[:, np.newaxis, :, np.newaxis]
+ f32_scaled = f32_reshaped * scale_expanded
+ result = f32_scaled.reshape(H, W)
+
+ return result
+
+
+def is_fp8_weight(tensor_name: str, tensor_names: list[str]) -> bool:
+ """Check if a weight tensor has an FP8 scale tensor."""
+ scale_name = tensor_name + "_scale_inv"
+ return scale_name in tensor_names
+
+
+def load_fp8_weight_direct(
+ st: SafeTensorsFile | ShardedSafeTensorsFile,
+ weight_name: str,
+ block_size: tuple[int, int] = (128, 128),
+) -> tuple[GPUArray, GPUArray]:
+ """Load FP8 weight directly without dequantization.
+
+ Returns:
+ (weight_fp8, scale_inv) tuple:
+ - weight_fp8: [out_features, in_features] as uint8
+ - scale_inv: [out/block_h, in/block_w] as bf16
+ """
+ from pygpukit.core.factory import from_numpy
+ from pygpukit.llm.safetensors import Dtype
+
+ # Load FP8 weight as uint8
+ info = st.tensor_info(weight_name)
+ data = st.tensor_bytes(weight_name)
+ fp8_bytes = np.frombuffer(data, dtype=np.uint8).reshape(info.shape).copy()
+ weight_fp8 = from_numpy(fp8_bytes)
+
+ # Load scale_inv tensor
+ scale_name = weight_name + "_scale_inv"
+ scale_info = st.tensor_info(scale_name)
+ scale_data = st.tensor_bytes(scale_name)
+
+ # scale_inv is typically bfloat16
+ if scale_info.dtype == Dtype.BFloat16:
+ scale_inv = np.frombuffer(scale_data, dtype=np.uint16).reshape(scale_info.shape).copy()
+ else:
+ # Convert float32 to bfloat16
+ scale_f32 = np.frombuffer(scale_data, dtype=np.float32).reshape(scale_info.shape)
+ uint32_view = scale_f32.view(np.uint32)
+ scale_inv = ((uint32_view + 0x7FFF + ((uint32_view >> 16) & 1)) >> 16).astype(np.uint16)
+
+ scale_inv_gpu = from_numpy(scale_inv)
+
+ return weight_fp8, scale_inv_gpu
+
+
+__all__ = [
+ # Quantization configs
+ "FP8QuantConfig",
+ "QATQuantConfig",
+ "PruningConfig",
+ "SparsityConfig",
+ "ModelOptimizationInfo",
+ # FP8 utilities
+ "dequantize_fp8_e4m3_block",
+ "is_fp8_weight",
+ "load_fp8_weight_direct",
+]
diff --git a/src/pygpukit/llm/repack.py b/src/pygpukit/llm/repack.py
new file mode 100644
index 0000000..5e6c4de
--- /dev/null
+++ b/src/pygpukit/llm/repack.py
@@ -0,0 +1,290 @@
+"""Model weight repacking for PyGPUkit LLM.
+
+Provides memory optimization by repacking weights into contiguous GPU memory
+to fix performance regression from fragmented allocation.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+from pygpukit.core.factory import from_numpy
+from pygpukit.llm.layers import MoELayer
+
+if TYPE_CHECKING:
+ from pygpukit.llm.model import CausalTransformerModel
+
+
+def repack_model_weights(model: CausalTransformerModel) -> None:
+ """Repack all model weights into contiguous GPU memory.
+
+ This fixes severe performance regression (7x slowdown) caused by
+ fragmented GPU memory allocation during model loading. Weights
+ allocated later end up in suboptimal memory regions.
+
+ The repacking is done in two phases:
+ 1. Convert ALL weights to numpy (freeing GPU memory)
+ 2. Reallocate ALL weights fresh in contiguous memory
+
+ Args:
+ model: CausalTransformerModel to repack in-place
+
+ Note:
+ MoE models are currently skipped (not repacked) due to different
+ weight structure. This will be addressed in a future update.
+ """
+ import gc
+
+ from pygpukit.core.array import GPUArray
+
+ # Skip repacking for MoE models (different weight structure)
+ if model.blocks and isinstance(model.blocks[0].mlp, MoELayer):
+ return
+
+ # Phase 1: Collect all weights as numpy arrays
+ numpy_cache: dict[int, dict] = {}
+ dummy_arrays: list[GPUArray] = []
+
+ # Embedding
+ embed_np = model.embed_tokens.to_numpy()
+ model.embed_tokens = None # type: ignore
+
+ # Position embedding
+ pos_embed_np = None
+ if model.position_embed is not None:
+ pos_embed_np = model.position_embed.to_numpy()
+ model.position_embed = None
+
+ # lm_head
+ lm_head_np = None
+ if model._lm_head is not None:
+ lm_head_np = model._lm_head.to_numpy()
+ model._lm_head = None
+
+ # Final norm
+ final_norm_weight_np = model.final_norm.weight.to_numpy()
+ final_norm_bias_np = None
+ if model.final_norm.bias is not None:
+ final_norm_bias_np = model.final_norm.bias.to_numpy()
+ model.final_norm.weight = None # type: ignore
+ model.final_norm.bias = None
+
+ # All blocks
+ for i, block in enumerate(model.blocks):
+ numpy_cache[i] = {}
+
+ # Attention norms
+ numpy_cache[i]["attn_norm_w"] = block.attn_norm.weight.to_numpy()
+ numpy_cache[i]["attn_norm_b"] = (
+ block.attn_norm.bias.to_numpy() if block.attn_norm.bias is not None else None
+ )
+ block.attn_norm.weight = None # type: ignore
+ block.attn_norm.bias = None
+
+ numpy_cache[i]["mlp_norm_w"] = block.mlp_norm.weight.to_numpy()
+ numpy_cache[i]["mlp_norm_b"] = (
+ block.mlp_norm.bias.to_numpy() if block.mlp_norm.bias is not None else None
+ )
+ block.mlp_norm.weight = None # type: ignore
+ block.mlp_norm.bias = None
+
+ # Attention projections
+ attn = block.attn
+ numpy_cache[i]["q_w"] = attn.q_proj.weight.to_numpy()
+ numpy_cache[i]["q_b"] = (
+ attn.q_proj.bias.to_numpy() if attn.q_proj.bias is not None else None
+ )
+ attn.q_proj.weight = None # type: ignore
+ attn.q_proj.bias = None
+ attn.q_proj._weight_t = None
+
+ numpy_cache[i]["k_w"] = attn.k_proj.weight.to_numpy()
+ numpy_cache[i]["k_b"] = (
+ attn.k_proj.bias.to_numpy() if attn.k_proj.bias is not None else None
+ )
+ attn.k_proj.weight = None # type: ignore
+ attn.k_proj.bias = None
+ attn.k_proj._weight_t = None
+
+ numpy_cache[i]["v_w"] = attn.v_proj.weight.to_numpy()
+ numpy_cache[i]["v_b"] = (
+ attn.v_proj.bias.to_numpy() if attn.v_proj.bias is not None else None
+ )
+ attn.v_proj.weight = None # type: ignore
+ attn.v_proj.bias = None
+ attn.v_proj._weight_t = None
+
+ numpy_cache[i]["o_w"] = attn.o_proj.weight.to_numpy()
+ numpy_cache[i]["o_b"] = (
+ attn.o_proj.bias.to_numpy() if attn.o_proj.bias is not None else None
+ )
+ attn.o_proj.weight = None # type: ignore
+ attn.o_proj.bias = None
+ attn.o_proj._weight_t = None
+
+ # QK norms
+ if attn.q_norm is not None:
+ numpy_cache[i]["q_norm_w"] = attn.q_norm.weight.to_numpy()
+ numpy_cache[i]["q_norm_b"] = (
+ attn.q_norm.bias.to_numpy() if attn.q_norm.bias is not None else None
+ )
+ attn.q_norm.weight = None # type: ignore
+ attn.q_norm.bias = None
+ if attn.k_norm is not None:
+ numpy_cache[i]["k_norm_w"] = attn.k_norm.weight.to_numpy()
+ numpy_cache[i]["k_norm_b"] = (
+ attn.k_norm.bias.to_numpy() if attn.k_norm.bias is not None else None
+ )
+ attn.k_norm.weight = None # type: ignore
+ attn.k_norm.bias = None
+
+ # MLP projections
+ mlp = block.mlp
+ if mlp.activation == "gelu":
+ numpy_cache[i]["fc1_w"] = mlp.fc1.weight.to_numpy()
+ numpy_cache[i]["fc1_b"] = mlp.fc1.bias.to_numpy() if mlp.fc1.bias is not None else None
+ mlp.fc1.weight = None # type: ignore
+ mlp.fc1.bias = None
+ mlp.fc1._weight_t = None
+
+ numpy_cache[i]["fc2_w"] = mlp.fc2.weight.to_numpy()
+ numpy_cache[i]["fc2_b"] = mlp.fc2.bias.to_numpy() if mlp.fc2.bias is not None else None
+ mlp.fc2.weight = None # type: ignore
+ mlp.fc2.bias = None
+ mlp.fc2._weight_t = None
+ else: # SwiGLU
+ numpy_cache[i]["gate_w"] = mlp.gate_proj.weight.to_numpy()
+ numpy_cache[i]["gate_b"] = (
+ mlp.gate_proj.bias.to_numpy() if mlp.gate_proj.bias is not None else None
+ )
+ mlp.gate_proj.weight = None # type: ignore
+ mlp.gate_proj.bias = None
+ mlp.gate_proj._weight_t = None
+
+ numpy_cache[i]["up_w"] = mlp.up_proj.weight.to_numpy()
+ numpy_cache[i]["up_b"] = (
+ mlp.up_proj.bias.to_numpy() if mlp.up_proj.bias is not None else None
+ )
+ mlp.up_proj.weight = None # type: ignore
+ mlp.up_proj.bias = None
+ mlp.up_proj._weight_t = None
+
+ numpy_cache[i]["down_w"] = mlp.down_proj.weight.to_numpy()
+ numpy_cache[i]["down_b"] = (
+ mlp.down_proj.bias.to_numpy() if mlp.down_proj.bias is not None else None
+ )
+ mlp.down_proj.weight = None # type: ignore
+ mlp.down_proj.bias = None
+ mlp.down_proj._weight_t = None
+
+ # Force garbage collection to free GPU memory
+ gc.collect()
+
+ # Allocate dummy arrays to fill the freed memory space
+ dummy_size = 1024 * 1024 * 512 # 512M elements = 1GB for FP16
+ try:
+ for _ in range(16): # Allocate ~16GB of dummy memory
+ dummy = from_numpy(np.zeros(dummy_size, dtype=np.float16))
+ dummy_arrays.append(dummy)
+ except Exception:
+ pass # Continue with whatever dummy memory we could allocate
+
+ # Phase 2: Reallocate all weights fresh (REVERSE order for memory optimization)
+ for i in reversed(range(len(model.blocks))):
+ block = model.blocks[i]
+ cache = numpy_cache[i]
+
+ # Attention norms
+ block.attn_norm.weight = from_numpy(cache["attn_norm_w"])
+ if cache["attn_norm_b"] is not None:
+ block.attn_norm.bias = from_numpy(cache["attn_norm_b"])
+
+ block.mlp_norm.weight = from_numpy(cache["mlp_norm_w"])
+ if cache["mlp_norm_b"] is not None:
+ block.mlp_norm.bias = from_numpy(cache["mlp_norm_b"])
+
+ # Attention projections
+ attn = block.attn
+ attn.q_proj.weight = from_numpy(cache["q_w"])
+ if cache["q_b"] is not None:
+ attn.q_proj.bias = from_numpy(cache["q_b"])
+
+ attn.k_proj.weight = from_numpy(cache["k_w"])
+ if cache["k_b"] is not None:
+ attn.k_proj.bias = from_numpy(cache["k_b"])
+
+ attn.v_proj.weight = from_numpy(cache["v_w"])
+ if cache["v_b"] is not None:
+ attn.v_proj.bias = from_numpy(cache["v_b"])
+
+ attn.o_proj.weight = from_numpy(cache["o_w"])
+ if cache["o_b"] is not None:
+ attn.o_proj.bias = from_numpy(cache["o_b"])
+
+ # QK norms
+ if "q_norm_w" in cache:
+ attn.q_norm.weight = from_numpy(cache["q_norm_w"])
+ if cache["q_norm_b"] is not None:
+ attn.q_norm.bias = from_numpy(cache["q_norm_b"])
+ if "k_norm_w" in cache:
+ attn.k_norm.weight = from_numpy(cache["k_norm_w"])
+ if cache["k_norm_b"] is not None:
+ attn.k_norm.bias = from_numpy(cache["k_norm_b"])
+
+ # MLP projections
+ mlp = block.mlp
+ if mlp.activation == "gelu":
+ mlp.fc1.weight = from_numpy(cache["fc1_w"])
+ if cache["fc1_b"] is not None:
+ mlp.fc1.bias = from_numpy(cache["fc1_b"])
+
+ mlp.fc2.weight = from_numpy(cache["fc2_w"])
+ if cache["fc2_b"] is not None:
+ mlp.fc2.bias = from_numpy(cache["fc2_b"])
+ else: # SwiGLU
+ mlp.gate_proj.weight = from_numpy(cache["gate_w"])
+ if cache["gate_b"] is not None:
+ mlp.gate_proj.bias = from_numpy(cache["gate_b"])
+
+ mlp.up_proj.weight = from_numpy(cache["up_w"])
+ if cache["up_b"] is not None:
+ mlp.up_proj.bias = from_numpy(cache["up_b"])
+
+ mlp.down_proj.weight = from_numpy(cache["down_w"])
+ if cache["down_b"] is not None:
+ mlp.down_proj.bias = from_numpy(cache["down_b"])
+
+ # Clear this block's cache immediately
+ del numpy_cache[i]
+
+ # Final norm
+ model.final_norm.weight = from_numpy(final_norm_weight_np)
+ if final_norm_bias_np is not None:
+ model.final_norm.bias = from_numpy(final_norm_bias_np)
+
+ # lm_head
+ if lm_head_np is not None:
+ model._lm_head = from_numpy(lm_head_np)
+
+ # Embedding and position embedding last
+ model.embed_tokens = from_numpy(embed_np)
+ del embed_np
+
+ if pos_embed_np is not None:
+ model.position_embed = from_numpy(pos_embed_np)
+ del pos_embed_np
+
+ # Clear any cached transposes
+ if hasattr(model, "_lm_head_t_cache"):
+ delattr(model, "_lm_head_t_cache")
+
+ # Free dummy arrays
+ del dummy_arrays
+ gc.collect()
+
+
+__all__ = [
+ "repack_model_weights",
+]
diff --git a/src/pygpukit/llm/safetensors.py b/src/pygpukit/llm/safetensors.py
new file mode 100644
index 0000000..b030629
--- /dev/null
+++ b/src/pygpukit/llm/safetensors.py
@@ -0,0 +1,410 @@
+"""SafeTensors file loading for PyGPUkit LLM.
+
+Provides:
+- Dtype: Tensor data type enumeration
+- TensorInfo: Metadata for a single tensor
+- SafeTensorsFile: Memory-mapped single SafeTensors file
+- ShardedSafeTensorsFile: Sharded model loader with lazy shard loading
+- load_safetensors: Unified loader function
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from pygpukit.core.backend import get_rust_module
+
+if TYPE_CHECKING:
+ from collections.abc import Sequence
+
+# Get the Rust llm module
+_rust = get_rust_module()
+_llm = _rust.llm if _rust else None
+
+
+class Dtype:
+ """Tensor data type enumeration."""
+
+ Float32 = 0
+ Float16 = 1
+ BFloat16 = 2
+ Float64 = 3
+ Float8E4M3 = 4 # FP8 E4M3 (1 sign, 4 exponent, 3 mantissa)
+ Float8E5M2 = 5 # FP8 E5M2 (1 sign, 5 exponent, 2 mantissa)
+ Int32 = 6
+ Int64 = 7
+ Int16 = 8
+ Int8 = 9
+ UInt8 = 10
+ Bool = 11
+
+ _NAMES = {
+ 0: "float32",
+ 1: "float16",
+ 2: "bfloat16",
+ 3: "float64",
+ 4: "float8_e4m3",
+ 5: "float8_e5m2",
+ 6: "int32",
+ 7: "int64",
+ 8: "int16",
+ 9: "int8",
+ 10: "uint8",
+ 11: "bool",
+ }
+
+ _SIZES = {
+ 0: 4, # float32
+ 1: 2, # float16
+ 2: 2, # bfloat16
+ 3: 8, # float64
+ 4: 1, # float8_e4m3
+ 5: 1, # float8_e5m2
+ 6: 4, # int32
+ 7: 8, # int64
+ 8: 2, # int16
+ 9: 1, # int8
+ 10: 1, # uint8
+ 11: 1, # bool
+ }
+
+ @classmethod
+ def element_size(cls, dtype: int) -> int:
+ """Get the size in bytes of a single element."""
+ return cls._SIZES.get(dtype, 0)
+
+ @classmethod
+ def name(cls, dtype: int) -> str:
+ """Get the string name of a dtype."""
+ return cls._NAMES.get(dtype, "unknown")
+
+
+class TensorInfo:
+ """Metadata for a single tensor in a safetensors file."""
+
+ def __init__(
+ self,
+ name: str,
+ dtype: int,
+ shape: Sequence[int],
+ offset: int,
+ size_bytes: int,
+ ):
+ self.name = name
+ self.dtype = dtype
+ self.shape = list(shape)
+ self.offset = offset
+ self.size_bytes = size_bytes
+
+ @property
+ def numel(self) -> int:
+ """Total number of elements."""
+ result = 1
+ for dim in self.shape:
+ result *= dim
+ return result
+
+ @property
+ def dtype_name(self) -> str:
+ """String name of the dtype."""
+ return Dtype.name(self.dtype)
+
+ def __repr__(self) -> str:
+ return (
+ f"TensorInfo(name='{self.name}', dtype={self.dtype_name}, "
+ f"shape={self.shape}, size_bytes={self.size_bytes})"
+ )
+
+
+class SafeTensorsFile:
+ """Memory-mapped SafeTensors file.
+
+ Provides efficient access to tensor metadata and data from a .safetensors file
+ using memory mapping for zero-copy data access.
+
+ Example:
+ >>> st = SafeTensorsFile("model.safetensors")
+ >>> print(st.tensor_names)
+ ['weight', 'bias']
+ >>> info = st.tensor_info('weight')
+ >>> print(info.shape, info.dtype_name)
+ [768, 768] float16
+ >>> data = st.tensor_bytes('weight')
+ """
+
+ def __init__(self, path: str):
+ """Open a safetensors file.
+
+ Args:
+ path: Path to the .safetensors file
+ """
+ if _llm is None:
+ raise RuntimeError("Rust LLM module not available")
+ self._inner = _llm.SafeTensorsFile(path)
+
+ @property
+ def tensor_names(self) -> list[str]:
+ """Get list of all tensor names."""
+ return self._inner.tensor_names
+
+ @property
+ def file_size(self) -> int:
+ """Total file size in bytes."""
+ return self._inner.file_size
+
+ @property
+ def num_tensors(self) -> int:
+ """Number of tensors in the file."""
+ return self._inner.num_tensors
+
+ def tensor_info(self, name: str) -> TensorInfo:
+ """Get metadata for a tensor by name.
+
+ Args:
+ name: Tensor name
+
+ Returns:
+ TensorInfo with dtype, shape, offset, and size
+
+ Raises:
+ KeyError: If tensor name not found
+ """
+ info = self._inner.tensor_info(name)
+ return TensorInfo(
+ name=info.name,
+ dtype=int(info.dtype),
+ shape=info.shape,
+ offset=info.offset,
+ size_bytes=info.size_bytes,
+ )
+
+ def tensor_bytes(self, name: str) -> bytes:
+ """Get raw tensor data as bytes.
+
+ Args:
+ name: Tensor name
+
+ Returns:
+ Raw bytes of the tensor data
+
+ Raises:
+ KeyError: If tensor name not found
+ """
+ return bytes(self._inner.tensor_bytes(name))
+
+ def tensor_as_f32(self, name: str):
+ """Get tensor data as numpy float32 array.
+
+ Args:
+ name: Tensor name
+
+ Returns:
+ 1D numpy array of float32 values
+
+ Raises:
+ KeyError: If tensor name not found
+ ValueError: If tensor dtype is not Float32
+ """
+ return self._inner.tensor_as_f32(name)
+
+ def tensor_data_ptr(self, name: str) -> tuple[int, int]:
+ """Get raw mmap pointer for direct GPU transfer.
+
+ Args:
+ name: Tensor name
+
+ Returns:
+ Tuple of (ptr, size_bytes) where ptr is the raw mmap address
+
+ Raises:
+ KeyError: If tensor name not found
+ """
+ return self._inner.tensor_data_ptr(name)
+
+ def __len__(self) -> int:
+ return self.num_tensors
+
+ def __contains__(self, name: str) -> bool:
+ return name in self._inner
+
+ def __repr__(self) -> str:
+ return f"SafeTensorsFile(num_tensors={self.num_tensors}, file_size={self.file_size})"
+
+
+class ShardedSafeTensorsFile:
+ """Sharded SafeTensors file loader.
+
+ Handles models split across multiple .safetensors files with an index.json.
+ Lazily opens shards on demand to minimize memory usage.
+
+ Example:
+ >>> st = ShardedSafeTensorsFile("model.safetensors.index.json")
+ >>> print(st.tensor_names[:5])
+ ['lm_head.weight', 'model.embed_tokens.weight', ...]
+ >>> info = st.tensor_info('model.embed_tokens.weight')
+ >>> data = st.tensor_bytes('model.embed_tokens.weight')
+ """
+
+ def __init__(self, index_json_path: str):
+ """Open a sharded safetensors model.
+
+ Args:
+ index_json_path: Path to model.safetensors.index.json
+ """
+ import json
+ from pathlib import Path
+
+ self._index_path = Path(index_json_path)
+ self._base_dir = self._index_path.parent
+
+ with open(index_json_path, encoding="utf-8") as f:
+ index = json.load(f)
+
+ # weight_map: { tensor_name: shard_filename }
+ self._weight_map: dict[str, str] = index.get("weight_map", {})
+ self._metadata = index.get("metadata", {})
+
+ # Lazy-loaded shard files
+ self._shards: dict[str, SafeTensorsFile] = {}
+
+ # Unique shard files
+ self._shard_files = list(set(self._weight_map.values()))
+
+ def _get_shard(self, shard_file: str) -> SafeTensorsFile:
+ """Lazily open a shard file."""
+ if shard_file not in self._shards:
+ shard_path = self._base_dir / shard_file
+ self._shards[shard_file] = SafeTensorsFile(str(shard_path))
+ return self._shards[shard_file]
+
+ @property
+ def tensor_names(self) -> list[str]:
+ """Get list of all tensor names across all shards."""
+ return list(self._weight_map.keys())
+
+ @property
+ def file_size(self) -> int:
+ """Total file size across all shards (lazy, opens all shards)."""
+ total = 0
+ for shard_file in self._shard_files:
+ total += self._get_shard(shard_file).file_size
+ return total
+
+ @property
+ def num_tensors(self) -> int:
+ """Number of tensors across all shards."""
+ return len(self._weight_map)
+
+ def tensor_info(self, name: str) -> TensorInfo:
+ """Get metadata for a tensor by name.
+
+ Args:
+ name: Tensor name
+
+ Returns:
+ TensorInfo with dtype, shape, offset, and size
+
+ Raises:
+ KeyError: If tensor name not found
+ """
+ if name not in self._weight_map:
+ raise KeyError(f"Tensor '{name}' not found")
+ shard_file = self._weight_map[name]
+ return self._get_shard(shard_file).tensor_info(name)
+
+ def tensor_bytes(self, name: str) -> bytes:
+ """Get raw tensor data as bytes.
+
+ Args:
+ name: Tensor name
+
+ Returns:
+ Raw bytes of the tensor data
+
+ Raises:
+ KeyError: If tensor name not found
+ """
+ if name not in self._weight_map:
+ raise KeyError(f"Tensor '{name}' not found")
+ shard_file = self._weight_map[name]
+ return self._get_shard(shard_file).tensor_bytes(name)
+
+ def tensor_as_f32(self, name: str):
+ """Get tensor data as numpy float32 array.
+
+ Args:
+ name: Tensor name
+
+ Returns:
+ 1D numpy array of float32 values
+
+ Raises:
+ KeyError: If tensor name not found
+ ValueError: If tensor dtype is not Float32
+ """
+ if name not in self._weight_map:
+ raise KeyError(f"Tensor '{name}' not found")
+ shard_file = self._weight_map[name]
+ return self._get_shard(shard_file).tensor_as_f32(name)
+
+ def tensor_data_ptr(self, name: str) -> tuple[int, int]:
+ """Get raw mmap pointer for direct GPU transfer.
+
+ Args:
+ name: Tensor name
+
+ Returns:
+ Tuple of (ptr, size_bytes) where ptr is the raw mmap address
+
+ Raises:
+ KeyError: If tensor name not found
+ """
+ if name not in self._weight_map:
+ raise KeyError(f"Tensor '{name}' not found")
+ shard_file = self._weight_map[name]
+ return self._get_shard(shard_file).tensor_data_ptr(name)
+
+ def __len__(self) -> int:
+ return self.num_tensors
+
+ def __contains__(self, name: str) -> bool:
+ return name in self._weight_map
+
+ def __repr__(self) -> str:
+ return (
+ f"ShardedSafeTensorsFile(num_tensors={self.num_tensors}, "
+ f"num_shards={len(self._shard_files)})"
+ )
+
+
+def load_safetensors(path: str) -> SafeTensorsFile | ShardedSafeTensorsFile:
+ """Load a safetensors file (single or sharded).
+
+ Automatically detects sharded models by .index.json extension.
+
+ Args:
+ path: Path to .safetensors file or .safetensors.index.json
+
+ Returns:
+ SafeTensorsFile or ShardedSafeTensorsFile for accessing tensor data
+
+ Example:
+ # Single file
+ st = load_safetensors("model.safetensors")
+
+ # Sharded model
+ st = load_safetensors("model.safetensors.index.json")
+ """
+ if path.endswith(".index.json"):
+ return ShardedSafeTensorsFile(path)
+ else:
+ return SafeTensorsFile(path)
+
+
+__all__ = [
+ "Dtype",
+ "TensorInfo",
+ "SafeTensorsFile",
+ "ShardedSafeTensorsFile",
+ "load_safetensors",
+]
diff --git a/src/pygpukit/llm/tokenizer.py b/src/pygpukit/llm/tokenizer.py
new file mode 100644
index 0000000..ea02c5d
--- /dev/null
+++ b/src/pygpukit/llm/tokenizer.py
@@ -0,0 +1,152 @@
+"""BPE Tokenizer for PyGPUkit LLM.
+
+**Note:** This tokenizer is experimental and intended for demos/testing only.
+For production use, we recommend HuggingFace tokenizers:
+- https://github.com/huggingface/tokenizers
+- pip install tokenizers
+
+PyGPUkit's core responsibility is GPU execution, not tokenization.
+The model API expects token IDs as input - use your preferred tokenizer
+to convert text to token IDs before passing to PyGPUkit models.
+"""
+
+from __future__ import annotations
+
+from pygpukit.core.backend import get_rust_module
+
+# Get the Rust llm module
+_rust = get_rust_module()
+_llm = _rust.llm if _rust else None
+
+
+class Tokenizer:
+ """BPE Tokenizer for GPT-2 style models.
+
+ **EXPERIMENTAL: This tokenizer is intended for demos and testing only.**
+
+ For production use, we recommend HuggingFace tokenizers:
+ - https://github.com/huggingface/tokenizers
+ - pip install tokenizers
+
+ PyGPUkit's core responsibility is GPU execution, not tokenization.
+ The model API expects token IDs as input - use your preferred tokenizer
+ to convert text to token IDs before passing to PyGPUkit models.
+
+ Limitations:
+ - Only supports a subset of HuggingFace tokenizer.json formats
+ - May not work with all models (e.g., Qwen3 uses unsupported format)
+ - No chat template support
+ - No special token handling beyond BOS/EOS/PAD
+
+ Example:
+ >>> # For demos/testing only
+ >>> tok = Tokenizer("tokenizer.json")
+ >>> ids = tok.encode("Hello, world!")
+ >>> text = tok.decode(ids)
+
+ >>> # For production, use HuggingFace tokenizers:
+ >>> from tokenizers import Tokenizer as HFTokenizer
+ >>> hf_tok = HFTokenizer.from_file("tokenizer.json")
+ >>> ids = hf_tok.encode("Hello, world!").ids
+ """
+
+ def __init__(self, path: str):
+ """Load tokenizer from tokenizer.json file.
+
+ Args:
+ path: Path to the tokenizer.json file
+ """
+ if _llm is None:
+ raise RuntimeError("Rust LLM module not available")
+ self._inner = _llm.Tokenizer(path)
+
+ @classmethod
+ def from_json(cls, json_str: str) -> Tokenizer:
+ """Load tokenizer from JSON string.
+
+ Args:
+ json_str: JSON string containing tokenizer config
+
+ Returns:
+ Tokenizer instance
+ """
+ if _llm is None:
+ raise RuntimeError("Rust LLM module not available")
+ instance = cls.__new__(cls)
+ instance._inner = _llm.Tokenizer.from_json(json_str)
+ return instance
+
+ @property
+ def vocab_size(self) -> int:
+ """Get vocabulary size."""
+ return self._inner.vocab_size
+
+ @property
+ def bos_token_id(self) -> int | None:
+ """Get BOS (beginning of sequence) token ID if available."""
+ return self._inner.bos_token_id
+
+ @property
+ def eos_token_id(self) -> int | None:
+ """Get EOS (end of sequence) token ID if available."""
+ return self._inner.eos_token_id
+
+ @property
+ def pad_token_id(self) -> int | None:
+ """Get PAD token ID if available."""
+ return self._inner.pad_token_id
+
+ def encode(self, text: str) -> list[int]:
+ """Encode text to token IDs.
+
+ Args:
+ text: Input text to encode
+
+ Returns:
+ List of token IDs
+ """
+ return list(self._inner.encode(text))
+
+ def decode(self, token_ids: list[int]) -> str:
+ """Decode token IDs to text.
+
+ Args:
+ token_ids: List of token IDs
+
+ Returns:
+ Decoded text string
+ """
+ return self._inner.decode(token_ids)
+
+ def id_to_token(self, token_id: int) -> str | None:
+ """Get token string for an ID.
+
+ Args:
+ token_id: Token ID
+
+ Returns:
+ Token string if ID is valid, None otherwise
+ """
+ return self._inner.id_to_token(token_id)
+
+ def token_to_id(self, token: str) -> int | None:
+ """Get ID for a token string.
+
+ Args:
+ token: Token string
+
+ Returns:
+ Token ID if token exists, None otherwise
+ """
+ return self._inner.token_to_id(token)
+
+ def __len__(self) -> int:
+ return self.vocab_size
+
+ def __repr__(self) -> str:
+ return f"Tokenizer(vocab_size={self.vocab_size})"
+
+
+__all__ = [
+ "Tokenizer",
+]
diff --git a/src/pygpukit/ops/audio.py b/src/pygpukit/ops/audio.py
deleted file mode 100644
index aba3381..0000000
--- a/src/pygpukit/ops/audio.py
+++ /dev/null
@@ -1,1827 +0,0 @@
-"""GPU Audio Processing Operations.
-
-This module provides GPU-accelerated audio processing for ASR/Whisper preprocessing:
-- PCM to float conversion
-- Stereo to mono conversion
-- Peak/RMS normalization
-- Resampling (48kHz -> 16kHz)
-
-Example:
- >>> import numpy as np
- >>> import pygpukit as gk
- >>> from pygpukit.ops import audio
- >>>
- >>> # Load PCM samples (int16)
- >>> pcm = np.array([0, 16384, -16384, 32767], dtype=np.int16)
- >>> buf = audio.from_pcm(pcm, sample_rate=48000)
- >>>
- >>> # Process audio
- >>> buf = buf.to_mono().resample(16000).normalize()
- >>> result = buf.data.to_numpy()
-"""
-
-from __future__ import annotations
-
-from dataclasses import dataclass
-
-import numpy as np
-
-from pygpukit.core import GPUArray
-from pygpukit.core import from_numpy as core_from_numpy
-from pygpukit.core.dtypes import float32, int16
-
-
-def _get_native():
- """Get the native module."""
- try:
- from pygpukit._native_loader import get_native_module
-
- return get_native_module()
- except ImportError:
- from pygpukit import _pygpukit_native
-
- return _pygpukit_native
-
-
-@dataclass
-class AudioBuffer:
- """GPU audio buffer with metadata.
-
- Attributes:
- data: GPUArray containing audio samples (float32)
- sample_rate: Sample rate in Hz
- channels: Number of channels (1=mono, 2=stereo)
- """
-
- data: GPUArray
- sample_rate: int
- channels: int
-
- def to_mono(self) -> AudioBuffer:
- """Convert stereo audio to mono.
-
- Returns:
- New AudioBuffer with mono audio (channels=1)
-
- Raises:
- ValueError: If already mono
- """
- if self.channels == 1:
- return self
-
- if self.channels != 2:
- raise ValueError(f"to_mono only supports stereo (2 channels), got {self.channels}")
-
- native = _get_native()
- mono_data = native.audio_stereo_to_mono(self.data._get_native())
-
- return AudioBuffer(
- data=GPUArray._wrap_native(mono_data),
- sample_rate=self.sample_rate,
- channels=1,
- )
-
- def resample(self, target_rate: int) -> AudioBuffer:
- """Resample audio to target sample rate.
-
- Currently supports:
- - 48000 -> 16000 (3:1 decimation for Whisper)
-
- Args:
- target_rate: Target sample rate in Hz
-
- Returns:
- New AudioBuffer with resampled audio
-
- Raises:
- ValueError: If sample rate conversion is not supported
- """
- if self.sample_rate == target_rate:
- return self
-
- native = _get_native()
- resampled = native.audio_resample(self.data._get_native(), self.sample_rate, target_rate)
-
- return AudioBuffer(
- data=GPUArray._wrap_native(resampled),
- sample_rate=target_rate,
- channels=self.channels,
- )
-
- def normalize(self, mode: str = "peak", target_db: float = -20.0) -> AudioBuffer:
- """Normalize audio level.
-
- Args:
- mode: Normalization mode ("peak" or "rms")
- target_db: Target level in dB (only used for RMS mode)
-
- Returns:
- Self (in-place normalization)
-
- Raises:
- ValueError: If mode is not "peak" or "rms"
- """
- native = _get_native()
-
- if mode == "peak":
- native.audio_normalize_peak(self.data._get_native())
- elif mode == "rms":
- native.audio_normalize_rms(self.data._get_native(), target_db)
- else:
- raise ValueError(f"Unknown normalization mode: {mode}. Use 'peak' or 'rms'.")
-
- return self
-
- def to_numpy(self) -> np.ndarray:
- """Convert audio data to NumPy array.
-
- Returns:
- NumPy array of float32 samples
- """
- return self.data.to_numpy()
-
- def __repr__(self) -> str:
- return (
- f"AudioBuffer(samples={self.data.shape[0]}, "
- f"sample_rate={self.sample_rate}, channels={self.channels})"
- )
-
-
-def from_pcm(
- samples: np.ndarray | GPUArray,
- sample_rate: int,
- channels: int = 1,
-) -> AudioBuffer:
- """Create AudioBuffer from PCM samples.
-
- Args:
- samples: PCM samples as int16 or float32 array
- sample_rate: Sample rate in Hz (e.g., 48000, 16000)
- channels: Number of channels (1=mono, 2=stereo)
-
- Returns:
- AudioBuffer with audio data on GPU
-
- Example:
- >>> pcm = np.array([0, 16384, -16384], dtype=np.int16)
- >>> buf = from_pcm(pcm, sample_rate=48000)
- """
- native = _get_native()
-
- # Convert to GPUArray if needed
- if isinstance(samples, np.ndarray):
- gpu_samples = core_from_numpy(samples)
- else:
- gpu_samples = samples
-
- # Convert int16 PCM to float32
- if gpu_samples.dtype == int16:
- float_data = native.audio_pcm_to_float32(gpu_samples._get_native())
- gpu_data = GPUArray._wrap_native(float_data)
- elif gpu_samples.dtype == float32:
- # Already float32, just use as-is
- gpu_data = gpu_samples
- else:
- raise ValueError(f"Unsupported dtype: {gpu_samples.dtype}. Use int16 or float32.")
-
- return AudioBuffer(
- data=gpu_data,
- sample_rate=sample_rate,
- channels=channels,
- )
-
-
-class AudioRingBuffer:
- """GPU-side ring buffer for streaming audio.
-
- Provides efficient circular buffer operations for real-time audio processing.
-
- Args:
- capacity: Buffer capacity in samples
- sample_rate: Sample rate in Hz (for metadata)
-
- Example:
- >>> ring = AudioRingBuffer(capacity=48000, sample_rate=16000) # 3 sec buffer
- >>> ring.write(chunk1)
- >>> ring.write(chunk2)
- >>> window = ring.read(16000) # Read 1 second
- """
-
- def __init__(self, capacity: int, sample_rate: int = 16000):
- from pygpukit.core import zeros
-
- self._buffer = zeros((capacity,), dtype="float32")
- self._capacity = capacity
- self._sample_rate = sample_rate
- self._write_pos = 0
- self._samples_written = 0
-
- @property
- def capacity(self) -> int:
- """Buffer capacity in samples."""
- return self._capacity
-
- @property
- def sample_rate(self) -> int:
- """Sample rate in Hz."""
- return self._sample_rate
-
- @property
- def samples_available(self) -> int:
- """Number of samples available for reading."""
- return min(self._samples_written, self._capacity)
-
- @property
- def duration_available(self) -> float:
- """Duration of available audio in seconds."""
- return self.samples_available / self._sample_rate
-
- def write(self, samples: np.ndarray | GPUArray) -> int:
- """Write samples to the ring buffer.
-
- Args:
- samples: Audio samples to write (float32)
-
- Returns:
- Number of samples written
- """
- native = _get_native()
-
- # Convert to GPUArray if needed
- if isinstance(samples, np.ndarray):
- gpu_samples = core_from_numpy(samples.astype(np.float32))
- else:
- gpu_samples = samples
-
- num_samples = gpu_samples.shape[0]
-
- # Write to ring buffer
- native.audio_ring_buffer_write(
- gpu_samples._get_native(),
- self._buffer._get_native(),
- self._write_pos,
- )
-
- # Update write position
- self._write_pos = (self._write_pos + num_samples) % self._capacity
- self._samples_written += num_samples
-
- return num_samples
-
- def read(self, num_samples: int, offset: int = 0) -> GPUArray:
- """Read samples from the ring buffer.
-
- Args:
- num_samples: Number of samples to read
- offset: Offset from current read position (0 = most recent)
-
- Returns:
- GPUArray of audio samples
- """
- native = _get_native()
-
- # Calculate read position (read from oldest available)
- if self._samples_written <= self._capacity:
- read_pos = offset
- else:
- read_pos = (self._write_pos + offset) % self._capacity
-
- result = native.audio_ring_buffer_read(
- self._buffer._get_native(),
- read_pos,
- num_samples,
- )
-
- return GPUArray._wrap_native(result)
-
- def clear(self) -> None:
- """Clear the buffer."""
- from pygpukit.core import zeros
-
- self._buffer = zeros((self._capacity,), dtype="float32")
- self._write_pos = 0
- self._samples_written = 0
-
- def __repr__(self) -> str:
- return (
- f"AudioRingBuffer(capacity={self._capacity}, "
- f"sample_rate={self._sample_rate}, "
- f"available={self.samples_available})"
- )
-
-
-class AudioStream:
- """High-level streaming audio processor.
-
- Provides chunked processing with windowing for smooth transitions.
- Suitable for real-time ASR preprocessing.
-
- Args:
- chunk_size: Processing chunk size in samples (default: 480 = 30ms @ 16kHz)
- hop_size: Hop size between chunks (default: chunk_size // 2 for 50% overlap)
- sample_rate: Sample rate in Hz
- buffer_duration: Ring buffer duration in seconds
-
- Example:
- >>> stream = AudioStream(chunk_size=480, sample_rate=16000)
- >>> for pcm_chunk in audio_source:
- ... stream.push(pcm_chunk)
- ... if stream.has_chunk():
- ... chunk = stream.pop_chunk()
- ... # Process chunk for ASR
- """
-
- def __init__(
- self,
- chunk_size: int = 480,
- hop_size: int | None = None,
- sample_rate: int = 16000,
- buffer_duration: float = 30.0,
- ):
- self._chunk_size = chunk_size
- self._hop_size = hop_size if hop_size is not None else chunk_size // 2
- self._sample_rate = sample_rate
-
- # Ring buffer for incoming audio
- buffer_samples = int(buffer_duration * sample_rate)
- self._ring_buffer = AudioRingBuffer(buffer_samples, sample_rate)
-
- # Track chunk position
- self._chunks_processed = 0
-
- @property
- def chunk_size(self) -> int:
- """Chunk size in samples."""
- return self._chunk_size
-
- @property
- def hop_size(self) -> int:
- """Hop size in samples."""
- return self._hop_size
-
- @property
- def sample_rate(self) -> int:
- """Sample rate in Hz."""
- return self._sample_rate
-
- def push(self, samples: np.ndarray | GPUArray) -> int:
- """Push audio samples to the stream.
-
- Args:
- samples: Audio samples (float32)
-
- Returns:
- Number of samples pushed
- """
- return self._ring_buffer.write(samples)
-
- def has_chunk(self) -> bool:
- """Check if a full chunk is available."""
- required = self._chunks_processed * self._hop_size + self._chunk_size
- return self._ring_buffer._samples_written >= required
-
- def pop_chunk(self, apply_window: bool = True) -> GPUArray:
- """Pop the next chunk from the stream.
-
- Args:
- apply_window: Whether to apply Hann window (default True)
-
- Returns:
- GPUArray containing the chunk
-
- Raises:
- RuntimeError: If no chunk is available
- """
- if not self.has_chunk():
- raise RuntimeError("No chunk available. Call has_chunk() first.")
-
- native = _get_native()
-
- # Calculate read offset
- read_offset = self._chunks_processed * self._hop_size
-
- # Read chunk from ring buffer
- chunk = self._ring_buffer.read(self._chunk_size, read_offset)
-
- # Apply window if requested
- if apply_window:
- native.audio_apply_hann_window(chunk._get_native())
-
- self._chunks_processed += 1
- return chunk
-
- def reset(self) -> None:
- """Reset the stream state."""
- self._ring_buffer.clear()
- self._chunks_processed = 0
-
- @property
- def chunks_available(self) -> int:
- """Number of complete chunks available."""
- if self._ring_buffer._samples_written < self._chunk_size:
- return 0
- available = self._ring_buffer._samples_written - self._chunk_size
- return available // self._hop_size + 1 - self._chunks_processed
-
- def __repr__(self) -> str:
- return (
- f"AudioStream(chunk_size={self._chunk_size}, "
- f"hop_size={self._hop_size}, "
- f"sample_rate={self._sample_rate}, "
- f"chunks_available={self.chunks_available})"
- )
-
-
-@dataclass
-class SpeechSegment:
- """Represents a detected speech segment.
-
- Attributes:
- start_sample: Start sample index
- end_sample: End sample index
- start_time: Start time in seconds
- end_time: End time in seconds
- """
-
- start_sample: int
- end_sample: int
- start_time: float
- end_time: float
-
-
-class VAD:
- """GPU-accelerated Voice Activity Detection.
-
- Detects speech segments in audio using energy and zero-crossing rate features.
- Supports adaptive thresholding and hangover smoothing for robust detection.
-
- Args:
- sample_rate: Audio sample rate in Hz (default: 16000)
- frame_ms: Frame duration in milliseconds (default: 20)
- hop_ms: Hop duration in milliseconds (default: 10)
- energy_threshold: Energy threshold for speech (default: auto)
- hangover_ms: Hangover duration in milliseconds (default: 100)
-
- Example:
- >>> vad = VAD(sample_rate=16000)
- >>> segments = vad.detect(audio_buffer)
- >>> for seg in segments:
- ... print(f"Speech: {seg.start_time:.2f}s - {seg.end_time:.2f}s")
- """
-
- def __init__(
- self,
- sample_rate: int = 16000,
- frame_ms: float = 20.0,
- hop_ms: float = 10.0,
- energy_threshold: float | None = None,
- hangover_ms: float = 100.0,
- zcr_low: float = 0.02,
- zcr_high: float = 0.25,
- ):
- self._sample_rate = sample_rate
- self._frame_size = int(frame_ms * sample_rate / 1000)
- self._hop_size = int(hop_ms * sample_rate / 1000)
- self._energy_threshold = energy_threshold
- self._hangover_frames = int(hangover_ms / hop_ms)
- self._zcr_low = zcr_low
- self._zcr_high = zcr_high
-
- # Adaptive threshold multiplier (above noise floor)
- self._adaptive_multiplier = 3.0
-
- @property
- def sample_rate(self) -> int:
- """Sample rate in Hz."""
- return self._sample_rate
-
- @property
- def frame_size(self) -> int:
- """Frame size in samples."""
- return self._frame_size
-
- @property
- def hop_size(self) -> int:
- """Hop size in samples."""
- return self._hop_size
-
- def detect(self, audio: AudioBuffer | GPUArray) -> list[SpeechSegment]:
- """Detect speech segments in audio.
-
- Args:
- audio: AudioBuffer or GPUArray of float32 samples
-
- Returns:
- List of SpeechSegment objects representing detected speech regions
- """
- native = _get_native()
-
- # Get audio data
- if isinstance(audio, AudioBuffer):
- data = audio.data
- else:
- data = audio
-
- # Compute frame features
- energy = native.vad_compute_energy(data._get_native(), self._frame_size, self._hop_size)
- zcr = native.vad_compute_zcr(data._get_native(), self._frame_size, self._hop_size)
-
- energy_gpu = GPUArray._wrap_native(energy)
- zcr_gpu = GPUArray._wrap_native(zcr)
-
- # Determine energy threshold
- if self._energy_threshold is not None:
- threshold = self._energy_threshold
- else:
- # Adaptive threshold: multiplier * noise_floor
- noise_floor = native.vad_compute_noise_floor(energy)
- threshold = max(noise_floor * self._adaptive_multiplier, 0.01)
-
- # VAD decision
- vad_flags = native.vad_decide(
- energy_gpu._get_native(),
- zcr_gpu._get_native(),
- threshold,
- self._zcr_low,
- self._zcr_high,
- )
- vad_flags_gpu = GPUArray._wrap_native(vad_flags)
-
- # Apply hangover smoothing
- if self._hangover_frames > 0:
- smoothed = native.vad_apply_hangover(vad_flags_gpu._get_native(), self._hangover_frames)
- vad_flags_gpu = GPUArray._wrap_native(smoothed)
-
- # Convert to segments
- return self._flags_to_segments(vad_flags_gpu)
-
- def _flags_to_segments(self, vad_flags: GPUArray) -> list[SpeechSegment]:
- """Convert frame-level VAD flags to speech segments."""
- flags: np.ndarray = vad_flags.to_numpy().astype(int)
-
- segments: list[SpeechSegment] = []
- in_speech = False
- start_frame = 0
-
- for i, flag in enumerate(flags):
- if flag == 1 and not in_speech:
- # Speech start
- in_speech = True
- start_frame = i
- elif flag == 0 and in_speech:
- # Speech end
- in_speech = False
- segments.append(self._create_segment(start_frame, i))
-
- # Handle case where speech continues to end
- if in_speech:
- segments.append(self._create_segment(start_frame, len(flags)))
-
- return segments
-
- def _create_segment(self, start_frame: int, end_frame: int) -> SpeechSegment:
- """Create a SpeechSegment from frame indices."""
- start_sample = start_frame * self._hop_size
- end_sample = end_frame * self._hop_size + self._frame_size
-
- return SpeechSegment(
- start_sample=start_sample,
- end_sample=end_sample,
- start_time=start_sample / self._sample_rate,
- end_time=end_sample / self._sample_rate,
- )
-
- def get_frame_features(self, audio: AudioBuffer | GPUArray) -> tuple[GPUArray, GPUArray]:
- """Get raw frame features (energy and ZCR) for analysis.
-
- Args:
- audio: AudioBuffer or GPUArray of float32 samples
-
- Returns:
- Tuple of (energy, zcr) GPUArrays
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- data = audio.data
- else:
- data = audio
-
- energy = native.vad_compute_energy(data._get_native(), self._frame_size, self._hop_size)
- zcr = native.vad_compute_zcr(data._get_native(), self._frame_size, self._hop_size)
-
- return GPUArray._wrap_native(energy), GPUArray._wrap_native(zcr)
-
- def __repr__(self) -> str:
- return (
- f"VAD(sample_rate={self._sample_rate}, "
- f"frame_size={self._frame_size}, "
- f"hop_size={self._hop_size}, "
- f"hangover_frames={self._hangover_frames})"
- )
-
-
-# =============================================================================
-# Audio Preprocessing Functions
-# =============================================================================
-
-
-def preemphasis(audio: AudioBuffer | GPUArray, alpha: float = 0.97) -> AudioBuffer | GPUArray:
- """Apply pre-emphasis filter to emphasize high-frequency components.
-
- Pre-emphasis is commonly used in speech processing to boost high frequencies
- that are typically attenuated during recording.
-
- Formula: y[n] = x[n] - alpha * x[n-1]
-
- Args:
- audio: AudioBuffer or GPUArray of float32 samples
- alpha: Pre-emphasis coefficient (default 0.97)
-
- Returns:
- Same type as input (modified in-place)
-
- Example:
- >>> buf = from_pcm(pcm_data, sample_rate=16000)
- >>> preemphasis(buf, alpha=0.97)
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- native.audio_preemphasis(audio.data._get_native(), alpha)
- return audio
- else:
- native.audio_preemphasis(audio._get_native(), alpha)
- return audio
-
-
-def deemphasis(audio: AudioBuffer | GPUArray, alpha: float = 0.97) -> AudioBuffer | GPUArray:
- """Apply de-emphasis filter (inverse of pre-emphasis).
-
- Used to restore the original spectral balance after pre-emphasis.
-
- Formula: y[n] = x[n] + alpha * y[n-1]
-
- Args:
- audio: AudioBuffer or GPUArray of float32 samples
- alpha: De-emphasis coefficient (default 0.97)
-
- Returns:
- Same type as input (modified in-place)
-
- Example:
- >>> buf = preemphasis(buf)
- >>> # ... processing ...
- >>> deemphasis(buf) # Restore original balance
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- native.audio_deemphasis(audio.data._get_native(), alpha)
- return audio
- else:
- native.audio_deemphasis(audio._get_native(), alpha)
- return audio
-
-
-def remove_dc(audio: AudioBuffer | GPUArray) -> AudioBuffer | GPUArray:
- """Remove DC offset from audio signal.
-
- Subtracts the mean value from all samples, centering the signal at zero.
- This is a simple but effective way to remove DC bias.
-
- Args:
- audio: AudioBuffer or GPUArray of float32 samples
-
- Returns:
- Same type as input (modified in-place)
-
- Example:
- >>> buf = from_pcm(pcm_data, sample_rate=16000)
- >>> remove_dc(buf)
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- native.audio_remove_dc(audio.data._get_native())
- return audio
- else:
- native.audio_remove_dc(audio._get_native())
- return audio
-
-
-def highpass_filter(
- audio: AudioBuffer | GPUArray,
- cutoff_hz: float = 20.0,
- sample_rate: int | None = None,
-) -> AudioBuffer | GPUArray:
- """Apply high-pass filter for DC removal.
-
- Uses a single-pole IIR high-pass filter, which is more effective than
- simple mean subtraction for removing low-frequency noise.
-
- Args:
- audio: AudioBuffer or GPUArray of float32 samples
- cutoff_hz: Cutoff frequency in Hz (default 20.0)
- sample_rate: Sample rate in Hz (auto-detected from AudioBuffer)
-
- Returns:
- Same type as input (modified in-place)
-
- Example:
- >>> buf = from_pcm(pcm_data, sample_rate=16000)
- >>> highpass_filter(buf, cutoff_hz=50.0) # Remove hum
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- sr = sample_rate if sample_rate is not None else audio.sample_rate
- native.audio_highpass_filter(audio.data._get_native(), cutoff_hz, sr)
- return audio
- else:
- sr = sample_rate if sample_rate is not None else 16000
- native.audio_highpass_filter(audio._get_native(), cutoff_hz, sr)
- return audio
-
-
-def noise_gate(audio: AudioBuffer | GPUArray, threshold: float = 0.01) -> AudioBuffer | GPUArray:
- """Apply simple noise gate.
-
- Zeros samples with absolute value below threshold. This is a hard gate
- that completely silences quiet sections.
-
- Args:
- audio: AudioBuffer or GPUArray of float32 samples
- threshold: Amplitude threshold (default 0.01)
-
- Returns:
- Same type as input (modified in-place)
-
- Example:
- >>> buf = from_pcm(pcm_data, sample_rate=16000)
- >>> noise_gate(buf, threshold=0.02)
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- native.audio_noise_gate(audio.data._get_native(), threshold)
- return audio
- else:
- native.audio_noise_gate(audio._get_native(), threshold)
- return audio
-
-
-def spectral_gate(
- audio: AudioBuffer | GPUArray,
- threshold: float = 0.01,
- attack_samples: int = 64,
- release_samples: int = 256,
-) -> AudioBuffer | GPUArray:
- """Apply spectral gate for noise reduction.
-
- A softer noise gate that attenuates (rather than silences) quiet sections
- based on short-term frame energy. Provides smoother transitions than
- a hard noise gate.
-
- Args:
- audio: AudioBuffer or GPUArray of float32 samples
- threshold: Energy threshold (linear scale, default 0.01)
- attack_samples: Frame size for energy computation (default 64)
- release_samples: Smoothing release in samples (default 256)
-
- Returns:
- Same type as input (modified in-place)
-
- Example:
- >>> buf = from_pcm(pcm_data, sample_rate=16000)
- >>> spectral_gate(buf, threshold=0.005) # Subtle noise reduction
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- native.audio_spectral_gate(
- audio.data._get_native(), threshold, attack_samples, release_samples
- )
- return audio
- else:
- native.audio_spectral_gate(audio._get_native(), threshold, attack_samples, release_samples)
- return audio
-
-
-def compute_short_term_energy(audio: AudioBuffer | GPUArray, frame_size: int = 256) -> GPUArray:
- """Compute short-term energy for analysis or adaptive processing.
-
- Divides the audio into non-overlapping frames and computes the mean
- energy (sum of squares / frame_size) for each frame.
-
- Args:
- audio: AudioBuffer or GPUArray of float32 samples
- frame_size: Frame size in samples (default 256)
-
- Returns:
- GPUArray of frame energies
-
- Example:
- >>> buf = from_pcm(pcm_data, sample_rate=16000)
- >>> energy = compute_short_term_energy(buf, frame_size=320) # 20ms @ 16kHz
- >>> print(f"Max energy: {energy.to_numpy().max():.4f}")
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- data = audio.data
- else:
- data = audio
-
- result = native.audio_compute_short_term_energy(data._get_native(), frame_size)
- return GPUArray._wrap_native(result)
-
-
-# =============================================================================
-# Spectral Processing Functions
-# =============================================================================
-
-
-def stft(
- audio: AudioBuffer | GPUArray,
- n_fft: int = 512,
- hop_length: int = 160,
- win_length: int = -1,
- center: bool = True,
-) -> GPUArray:
- """Compute Short-Time Fourier Transform (STFT).
-
- Uses a custom Radix-2 FFT implementation (no cuFFT dependency).
-
- Args:
- audio: AudioBuffer or GPUArray of float32 samples
- n_fft: FFT size (must be power of 2, default 512)
- hop_length: Hop size (default 160)
- win_length: Window length (default n_fft)
- center: Whether to pad input with reflection (default True)
-
- Returns:
- Complex STFT output [n_frames, n_fft/2+1, 2] (real, imag)
-
- Example:
- >>> buf = from_pcm(pcm_data, sample_rate=16000)
- >>> stft_out = stft(buf, n_fft=512, hop_length=160)
- >>> print(f"STFT shape: {stft_out.shape}") # [n_frames, 257, 2]
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- data = audio.data
- else:
- data = audio
-
- result = native.audio_stft(data._get_native(), n_fft, hop_length, win_length, center)
- return GPUArray._wrap_native(result)
-
-
-def power_spectrum(stft_output: GPUArray) -> GPUArray:
- """Compute power spectrogram from STFT output.
-
- power = real^2 + imag^2
-
- Args:
- stft_output: STFT output [n_frames, n_freq, 2]
-
- Returns:
- Power spectrogram [n_frames, n_freq]
-
- Example:
- >>> stft_out = stft(buf, n_fft=512)
- >>> power = power_spectrum(stft_out)
- """
- native = _get_native()
- result = native.audio_power_spectrum(stft_output._get_native())
- return GPUArray._wrap_native(result)
-
-
-def magnitude_spectrum(stft_output: GPUArray) -> GPUArray:
- """Compute magnitude spectrogram from STFT output.
-
- magnitude = sqrt(real^2 + imag^2)
-
- Args:
- stft_output: STFT output [n_frames, n_freq, 2]
-
- Returns:
- Magnitude spectrogram [n_frames, n_freq]
-
- Example:
- >>> stft_out = stft(buf, n_fft=512)
- >>> mag = magnitude_spectrum(stft_out)
- """
- native = _get_native()
- result = native.audio_magnitude_spectrum(stft_output._get_native())
- return GPUArray._wrap_native(result)
-
-
-def create_mel_filterbank(
- n_mels: int = 80,
- n_fft: int = 512,
- sample_rate: int = 16000,
- f_min: float = 0.0,
- f_max: float = -1.0,
-) -> GPUArray:
- """Create Mel filterbank matrix.
-
- Args:
- n_mels: Number of mel bands (default 80 for Whisper)
- n_fft: FFT size
- sample_rate: Sample rate in Hz
- f_min: Minimum frequency (default 0)
- f_max: Maximum frequency (default sample_rate/2)
-
- Returns:
- Mel filterbank matrix [n_mels, n_fft/2+1]
-
- Example:
- >>> mel_fb = create_mel_filterbank(n_mels=80, n_fft=512, sample_rate=16000)
- """
- native = _get_native()
- result = native.audio_create_mel_filterbank(n_mels, n_fft, sample_rate, f_min, f_max)
- return GPUArray._wrap_native(result)
-
-
-def apply_mel_filterbank(spectrogram: GPUArray, mel_filterbank: GPUArray) -> GPUArray:
- """Apply Mel filterbank to power/magnitude spectrogram.
-
- Args:
- spectrogram: Input spectrogram [n_frames, n_fft/2+1]
- mel_filterbank: Mel filterbank [n_mels, n_fft/2+1]
-
- Returns:
- Mel spectrogram [n_frames, n_mels]
-
- Example:
- >>> power = power_spectrum(stft_out)
- >>> mel_fb = create_mel_filterbank(n_mels=80, n_fft=512)
- >>> mel = apply_mel_filterbank(power, mel_fb)
- """
- native = _get_native()
- result = native.audio_apply_mel_filterbank(
- spectrogram._get_native(), mel_filterbank._get_native()
- )
- return GPUArray._wrap_native(result)
-
-
-def log_mel(mel_spectrogram: GPUArray, eps: float = 1e-10) -> GPUArray:
- """Compute log-mel spectrogram.
-
- log_mel = log(mel + eps)
-
- Args:
- mel_spectrogram: Mel spectrogram [n_frames, n_mels]
- eps: Small constant for numerical stability (default 1e-10)
-
- Returns:
- Log-mel spectrogram [n_frames, n_mels]
-
- Example:
- >>> log_mel_spec = log_mel(mel_spectrogram)
- """
- native = _get_native()
- result = native.audio_log_mel_spectrogram(mel_spectrogram._get_native(), eps)
- return GPUArray._wrap_native(result)
-
-
-def to_decibels(audio: AudioBuffer | GPUArray, eps: float = 1e-10) -> GPUArray:
- """Convert to decibels.
-
- dB = 10 * log10(x + eps)
-
- Args:
- audio: Input array (power values)
- eps: Small constant for numerical stability (default 1e-10)
-
- Returns:
- dB values
-
- Example:
- >>> power = power_spectrum(stft_out)
- >>> db = to_decibels(power)
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- data = audio.data
- else:
- data = audio
-
- result = native.audio_to_decibels(data._get_native(), eps)
- return GPUArray._wrap_native(result)
-
-
-def mfcc(log_mel_input: GPUArray, n_mfcc: int = 13) -> GPUArray:
- """Compute MFCC from log-mel spectrogram using DCT-II.
-
- Args:
- log_mel_input: Log-mel spectrogram [n_frames, n_mels]
- n_mfcc: Number of MFCC coefficients (default 13)
-
- Returns:
- MFCC [n_frames, n_mfcc]
-
- Example:
- >>> log_mel_spec = log_mel(mel_spectrogram)
- >>> mfcc_features = mfcc(log_mel_spec, n_mfcc=13)
- """
- native = _get_native()
- result = native.audio_mfcc(log_mel_input._get_native(), n_mfcc)
- return GPUArray._wrap_native(result)
-
-
-def delta(features: GPUArray, order: int = 1, width: int = 2) -> GPUArray:
- """Compute delta (differential) features.
-
- Args:
- features: Input features [n_frames, n_features]
- order: Delta order (1 for delta, 2 for delta-delta)
- width: Window width for computation (default 2)
-
- Returns:
- Delta features [n_frames, n_features]
-
- Example:
- >>> mfcc_features = mfcc(log_mel_spec)
- >>> delta_mfcc = delta(mfcc_features, order=1)
- >>> delta_delta_mfcc = delta(mfcc_features, order=2)
- """
- native = _get_native()
- result = native.audio_delta_features(features._get_native(), order, width)
- return GPUArray._wrap_native(result)
-
-
-def mel_spectrogram(
- audio: AudioBuffer | GPUArray,
- n_fft: int = 512,
- hop_length: int = 160,
- n_mels: int = 80,
- sample_rate: int = 16000,
- f_min: float = 0.0,
- f_max: float = -1.0,
-) -> GPUArray:
- """Compute mel spectrogram.
-
- Combines: STFT -> power -> mel filterbank
-
- Args:
- audio: Input audio (float32)
- n_fft: FFT size (must be power of 2)
- hop_length: Hop size
- n_mels: Number of mel bands
- sample_rate: Sample rate in Hz
- f_min: Minimum frequency
- f_max: Maximum frequency (-1 for sample_rate/2)
-
- Returns:
- Mel spectrogram [n_frames, n_mels]
-
- Example:
- >>> mel = mel_spectrogram(buf, n_fft=512, hop_length=160, n_mels=80)
- """
- if isinstance(audio, AudioBuffer):
- data = audio.data
- else:
- data = audio
-
- # STFT
- stft_out = stft(data, n_fft=n_fft, hop_length=hop_length, center=True)
-
- # Power spectrum
- power = power_spectrum(stft_out)
-
- # Create and apply mel filterbank
- mel_fb = create_mel_filterbank(n_mels, n_fft, sample_rate, f_min, f_max)
- mel = apply_mel_filterbank(power, mel_fb)
-
- return mel
-
-
-def log_mel_spectrogram(
- audio: AudioBuffer | GPUArray,
- n_fft: int = 512,
- hop_length: int = 160,
- n_mels: int = 80,
- sample_rate: int = 16000,
- f_min: float = 0.0,
- f_max: float = -1.0,
- eps: float = 1e-10,
-) -> GPUArray:
- """Compute log-mel spectrogram (Whisper-compatible).
-
- Combines: STFT -> power -> mel filterbank -> log
-
- Args:
- audio: Input audio (float32, 16kHz expected for Whisper)
- n_fft: FFT size (must be power of 2)
- hop_length: Hop size
- n_mels: Number of mel bands (80 for Whisper)
- sample_rate: Sample rate in Hz
- f_min: Minimum frequency
- f_max: Maximum frequency (-1 for sample_rate/2)
- eps: Small constant for log stability
-
- Returns:
- Log-mel spectrogram [n_frames, n_mels]
-
- Example:
- >>> # Whisper-style mel spectrogram
- >>> buf = from_pcm(pcm_data, sample_rate=16000)
- >>> log_mel = log_mel_spectrogram(buf, n_fft=512, hop_length=160, n_mels=80)
- """
- mel = mel_spectrogram(audio, n_fft, hop_length, n_mels, sample_rate, f_min, f_max)
- return log_mel(mel, eps)
-
-
-# =============================================================================
-# Inverse STFT and Phase Reconstruction
-# =============================================================================
-
-
-def istft(
- stft_output: GPUArray,
- hop_length: int = 160,
- win_length: int = -1,
- center: bool = True,
- length: int = -1,
-) -> GPUArray:
- """Compute Inverse Short-Time Fourier Transform (ISTFT).
-
- Reconstructs time-domain signal from complex STFT representation
- using overlap-add with window sum normalization.
-
- Args:
- stft_output: Complex STFT [n_frames, n_freq, 2] (real, imag)
- hop_length: Hop size (default 160)
- win_length: Window length (default: (n_freq-1)*2)
- center: Whether input was centered (default True)
- length: Output length (-1 for automatic)
-
- Returns:
- Time-domain signal [n_samples]
-
- Example:
- >>> stft_out = stft(buf, n_fft=512, hop_length=160)
- >>> reconstructed = istft(stft_out, hop_length=160)
- """
- native = _get_native()
- result = native.audio_istft(stft_output._get_native(), hop_length, win_length, center, length)
- return GPUArray._wrap_native(result)
-
-
-def griffin_lim(
- magnitude: GPUArray,
- n_iter: int = 32,
- hop_length: int = 160,
- win_length: int = -1,
-) -> GPUArray:
- """Griffin-Lim algorithm for phase reconstruction.
-
- Reconstructs time-domain signal from magnitude spectrogram only,
- iteratively estimating phase using STFT/ISTFT consistency.
-
- Args:
- magnitude: Magnitude spectrogram [n_frames, n_freq]
- n_iter: Number of iterations (default 32)
- hop_length: Hop size (default 160)
- win_length: Window length (default: (n_freq-1)*2)
-
- Returns:
- Reconstructed time-domain signal [n_samples]
-
- Example:
- >>> mag = magnitude_spectrum(stft_out)
- >>> reconstructed = griffin_lim(mag, n_iter=32)
- """
- native = _get_native()
- result = native.audio_griffin_lim(magnitude._get_native(), n_iter, hop_length, win_length)
- return GPUArray._wrap_native(result)
-
-
-# =============================================================================
-# Pitch Detection
-# =============================================================================
-
-
-def autocorrelation(audio: AudioBuffer | GPUArray, max_lag: int) -> GPUArray:
- """Compute autocorrelation function.
-
- Args:
- audio: Input audio (float32)
- max_lag: Maximum lag in samples
-
- Returns:
- Autocorrelation values [max_lag]
-
- Example:
- >>> acf = autocorrelation(buf, max_lag=400) # 25ms @ 16kHz
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- data = audio.data
- else:
- data = audio
-
- result = native.audio_autocorrelation(data._get_native(), max_lag)
- return GPUArray._wrap_native(result)
-
-
-def detect_pitch_yin(
- audio: AudioBuffer | GPUArray,
- sample_rate: int = 16000,
- f_min: float = 50.0,
- f_max: float = 500.0,
- threshold: float = 0.1,
-) -> float:
- """Detect pitch using YIN algorithm.
-
- The YIN algorithm detects the fundamental frequency of a quasi-periodic
- signal using cumulative mean normalized difference function.
-
- Args:
- audio: Input audio frame (float32)
- sample_rate: Sample rate in Hz
- f_min: Minimum frequency to detect (default 50 Hz)
- f_max: Maximum frequency to detect (default 500 Hz)
- threshold: YIN threshold (default 0.1)
-
- Returns:
- Detected pitch in Hz (0.0 if unvoiced)
-
- Example:
- >>> pitch = detect_pitch_yin(audio_frame, sample_rate=16000)
- >>> if pitch > 0:
- ... print(f"Pitch: {pitch:.1f} Hz")
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- data = audio.data
- else:
- data = audio
-
- return native.audio_detect_pitch_yin(data._get_native(), sample_rate, f_min, f_max, threshold)
-
-
-def detect_pitch_yin_frames(
- audio: AudioBuffer | GPUArray,
- sample_rate: int = 16000,
- frame_size: int = 1024,
- hop_size: int = 256,
- f_min: float = 50.0,
- f_max: float = 500.0,
- threshold: float = 0.1,
-) -> GPUArray:
- """Detect pitch for each frame using YIN algorithm.
-
- Args:
- audio: Input audio (float32)
- sample_rate: Sample rate in Hz
- frame_size: Frame size in samples (default 1024)
- hop_size: Hop size in samples (default 256)
- f_min: Minimum frequency to detect (default 50 Hz)
- f_max: Maximum frequency to detect (default 500 Hz)
- threshold: YIN threshold (default 0.1)
-
- Returns:
- Pitch values for each frame [n_frames]
-
- Example:
- >>> pitches = detect_pitch_yin_frames(buf, sample_rate=16000)
- >>> voiced = pitches.to_numpy() > 0
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- data = audio.data
- else:
- data = audio
-
- result = native.audio_detect_pitch_yin_frames(
- data._get_native(), sample_rate, frame_size, hop_size, f_min, f_max, threshold
- )
- return GPUArray._wrap_native(result)
-
-
-# =============================================================================
-# Spectral Features
-# =============================================================================
-
-
-def spectral_centroid(
- spectrum: GPUArray,
- sample_rate: int = 16000,
-) -> GPUArray:
- """Compute spectral centroid for each frame.
-
- The spectral centroid indicates the "center of mass" of the spectrum.
-
- Args:
- spectrum: Magnitude or power spectrum [n_frames, n_freq]
- sample_rate: Sample rate in Hz
-
- Returns:
- Spectral centroid in Hz for each frame [n_frames]
-
- Example:
- >>> mag = magnitude_spectrum(stft_out)
- >>> centroid = spectral_centroid(mag, sample_rate=16000)
- """
- native = _get_native()
- result = native.audio_spectral_centroid(spectrum._get_native(), sample_rate)
- return GPUArray._wrap_native(result)
-
-
-def spectral_bandwidth(
- spectrum: GPUArray,
- centroids: GPUArray,
- sample_rate: int = 16000,
- p: int = 2,
-) -> GPUArray:
- """Compute spectral bandwidth for each frame.
-
- Spectral bandwidth is the weighted standard deviation of frequencies
- around the spectral centroid.
-
- Args:
- spectrum: Magnitude or power spectrum [n_frames, n_freq]
- centroids: Pre-computed spectral centroids [n_frames]
- sample_rate: Sample rate in Hz
- p: Order for bandwidth computation (default 2)
-
- Returns:
- Spectral bandwidth in Hz for each frame [n_frames]
-
- Example:
- >>> mag = magnitude_spectrum(stft_out)
- >>> centroid = spectral_centroid(mag, sample_rate=16000)
- >>> bandwidth = spectral_bandwidth(mag, centroid, sample_rate=16000)
- """
- native = _get_native()
- result = native.audio_spectral_bandwidth(
- spectrum._get_native(), centroids._get_native(), sample_rate, p
- )
- return GPUArray._wrap_native(result)
-
-
-def spectral_rolloff(
- spectrum: GPUArray,
- sample_rate: int = 16000,
- roll_percent: float = 0.85,
-) -> GPUArray:
- """Compute spectral rolloff for each frame.
-
- The rolloff frequency is the frequency below which roll_percent of
- the total spectral energy is contained.
-
- Args:
- spectrum: Magnitude or power spectrum [n_frames, n_freq]
- sample_rate: Sample rate in Hz
- roll_percent: Percentage of energy (default 0.85)
-
- Returns:
- Rolloff frequency in Hz for each frame [n_frames]
-
- Example:
- >>> mag = magnitude_spectrum(stft_out)
- >>> rolloff = spectral_rolloff(mag, sample_rate=16000, roll_percent=0.85)
- """
- native = _get_native()
- result = native.audio_spectral_rolloff(spectrum._get_native(), sample_rate, roll_percent)
- return GPUArray._wrap_native(result)
-
-
-def spectral_flatness(spectrum: GPUArray) -> GPUArray:
- """Compute spectral flatness for each frame.
-
- Spectral flatness measures how tone-like vs noise-like a sound is.
- Values close to 1 indicate noise, values close to 0 indicate tonal content.
-
- Computed as: geometric_mean / arithmetic_mean
-
- Args:
- spectrum: Magnitude or power spectrum [n_frames, n_freq]
-
- Returns:
- Spectral flatness for each frame [n_frames] (0 to 1)
-
- Example:
- >>> mag = magnitude_spectrum(stft_out)
- >>> flatness = spectral_flatness(mag)
- """
- native = _get_native()
- result = native.audio_spectral_flatness(spectrum._get_native())
- return GPUArray._wrap_native(result)
-
-
-def spectral_contrast(
- spectrum: GPUArray,
- n_bands: int = 6,
- alpha: float = 0.2,
-) -> GPUArray:
- """Compute spectral contrast for each frame.
-
- Spectral contrast measures the difference between peaks and valleys
- in the spectrum, divided into frequency bands.
-
- Args:
- spectrum: Magnitude or power spectrum [n_frames, n_freq]
- n_bands: Number of frequency bands (default 6)
- alpha: Percentile for peak/valley estimation (default 0.2)
-
- Returns:
- Spectral contrast [n_frames, n_bands]
-
- Example:
- >>> mag = magnitude_spectrum(stft_out)
- >>> contrast = spectral_contrast(mag, n_bands=6)
- """
- native = _get_native()
- result = native.audio_spectral_contrast(spectrum._get_native(), n_bands, alpha)
- return GPUArray._wrap_native(result)
-
-
-def zero_crossing_rate(
- audio: AudioBuffer | GPUArray,
- frame_size: int = 512,
- hop_size: int = 256,
-) -> GPUArray:
- """Compute zero-crossing rate for each frame.
-
- ZCR counts the number of times the signal crosses zero per frame,
- normalized by frame size.
-
- Args:
- audio: Input audio (float32)
- frame_size: Frame size in samples (default 512)
- hop_size: Hop size in samples (default 256)
-
- Returns:
- Zero-crossing rate for each frame [n_frames]
-
- Example:
- >>> zcr = zero_crossing_rate(buf, frame_size=512, hop_size=256)
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- data = audio.data
- else:
- data = audio
-
- result = native.audio_zero_crossing_rate(data._get_native(), frame_size, hop_size)
- return GPUArray._wrap_native(result)
-
-
-# =============================================================================
-# Constant-Q Transform and Chromagram
-# =============================================================================
-
-
-def cqt(
- audio: AudioBuffer | GPUArray,
- sample_rate: int = 16000,
- hop_length: int = 160,
- f_min: float = 32.7,
- n_bins: int = 84,
- bins_per_octave: int = 12,
-) -> GPUArray:
- """Compute Constant-Q Transform (CQT).
-
- CQT provides logarithmically-spaced frequency resolution, useful for
- music analysis where notes are logarithmically distributed.
-
- This implementation uses STFT-based approximation for efficiency.
-
- Args:
- audio: Input audio (float32)
- sample_rate: Sample rate in Hz
- hop_length: Hop size (default 160)
- f_min: Minimum frequency (default 32.7 Hz = C1)
- n_bins: Number of frequency bins (default 84 = 7 octaves)
- bins_per_octave: Bins per octave (default 12)
-
- Returns:
- Complex CQT [n_frames, n_bins, 2] (real, imag)
-
- Example:
- >>> cqt_out = cqt(buf, sample_rate=16000, n_bins=84)
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- data = audio.data
- else:
- data = audio
-
- result = native.audio_cqt(
- data._get_native(), sample_rate, hop_length, f_min, n_bins, bins_per_octave
- )
- return GPUArray._wrap_native(result)
-
-
-def cqt_magnitude(
- audio: AudioBuffer | GPUArray,
- sample_rate: int = 16000,
- hop_length: int = 160,
- f_min: float = 32.7,
- n_bins: int = 84,
- bins_per_octave: int = 12,
-) -> GPUArray:
- """Compute CQT magnitude spectrogram.
-
- Convenience function that computes CQT and returns magnitude.
-
- Args:
- audio: Input audio (float32)
- sample_rate: Sample rate in Hz
- hop_length: Hop size (default 160)
- f_min: Minimum frequency (default 32.7 Hz = C1)
- n_bins: Number of frequency bins (default 84)
- bins_per_octave: Bins per octave (default 12)
-
- Returns:
- CQT magnitude [n_frames, n_bins]
-
- Example:
- >>> cqt_mag = cqt_magnitude(buf, sample_rate=16000)
- """
- cqt_out = cqt(audio, sample_rate, hop_length, f_min, n_bins, bins_per_octave)
- return magnitude_spectrum(cqt_out)
-
-
-def chroma_stft(
- spectrum: GPUArray,
- sample_rate: int = 16000,
- n_chroma: int = 12,
- tuning: float = 0.0,
-) -> GPUArray:
- """Compute chromagram from STFT magnitude spectrum.
-
- Maps the spectrum to 12 pitch classes (C, C#, D, ..., B).
-
- Args:
- spectrum: Magnitude spectrum [n_frames, n_freq]
- sample_rate: Sample rate in Hz
- n_chroma: Number of chroma bins (default 12)
- tuning: Tuning deviation in fractions of a chroma bin (default 0)
-
- Returns:
- Chromagram [n_frames, n_chroma]
-
- Example:
- >>> mag = magnitude_spectrum(stft_out)
- >>> chroma = chroma_stft(mag, sample_rate=16000)
- """
- native = _get_native()
- result = native.audio_chroma_stft(spectrum._get_native(), sample_rate, n_chroma, tuning)
- return GPUArray._wrap_native(result)
-
-
-def chroma_cqt(
- cqt_magnitude_input: GPUArray,
- bins_per_octave: int = 12,
-) -> GPUArray:
- """Compute chromagram from CQT magnitude.
-
- Args:
- cqt_magnitude_input: CQT magnitude [n_frames, n_bins]
- bins_per_octave: Bins per octave in CQT (default 12)
-
- Returns:
- Chromagram [n_frames, bins_per_octave]
-
- Example:
- >>> cqt_mag = cqt_magnitude(buf, bins_per_octave=12)
- >>> chroma = chroma_cqt(cqt_mag, bins_per_octave=12)
- """
- native = _get_native()
- result = native.audio_chroma_cqt(cqt_magnitude_input._get_native(), bins_per_octave)
- return GPUArray._wrap_native(result)
-
-
-# =============================================================================
-# Harmonic-Percussive Source Separation (HPSS)
-# =============================================================================
-
-
-def hpss(
- stft_magnitude_input: GPUArray,
- kernel_size: int = 31,
- power: float = 2.0,
- margin: float = 1.0,
-) -> tuple[GPUArray, GPUArray]:
- """Harmonic-Percussive Source Separation using median filtering.
-
- Separates audio into harmonic (tonal) and percussive (transient) components
- using median filtering in time and frequency directions.
-
- Args:
- stft_magnitude_input: STFT magnitude [n_frames, n_freq]
- kernel_size: Median filter kernel size (default 31)
- power: Power for spectrogram (default 2.0)
- margin: Margin for soft masking (default 1.0)
-
- Returns:
- Tuple of (harmonic_magnitude, percussive_magnitude)
-
- Example:
- >>> mag = magnitude_spectrum(stft_out)
- >>> harmonic, percussive = hpss(mag)
- """
- native = _get_native()
- h, p = native.audio_hpss(stft_magnitude_input._get_native(), kernel_size, power, margin)
- return GPUArray._wrap_native(h), GPUArray._wrap_native(p)
-
-
-def harmonic(
- stft_magnitude_input: GPUArray,
- kernel_size: int = 31,
- power: float = 2.0,
- margin: float = 1.0,
-) -> GPUArray:
- """Extract harmonic component using HPSS.
-
- Args:
- stft_magnitude_input: STFT magnitude [n_frames, n_freq]
- kernel_size: Median filter kernel size (default 31)
- power: Power for spectrogram (default 2.0)
- margin: Margin for soft masking (default 1.0)
-
- Returns:
- Harmonic magnitude [n_frames, n_freq]
-
- Example:
- >>> mag = magnitude_spectrum(stft_out)
- >>> harm = harmonic(mag)
- """
- h, _ = hpss(stft_magnitude_input, kernel_size, power, margin)
- return h
-
-
-def percussive(
- stft_magnitude_input: GPUArray,
- kernel_size: int = 31,
- power: float = 2.0,
- margin: float = 1.0,
-) -> GPUArray:
- """Extract percussive component using HPSS.
-
- Args:
- stft_magnitude_input: STFT magnitude [n_frames, n_freq]
- kernel_size: Median filter kernel size (default 31)
- power: Power for spectrogram (default 2.0)
- margin: Margin for soft masking (default 1.0)
-
- Returns:
- Percussive magnitude [n_frames, n_freq]
-
- Example:
- >>> mag = magnitude_spectrum(stft_out)
- >>> perc = percussive(mag)
- """
- _, p = hpss(stft_magnitude_input, kernel_size, power, margin)
- return p
-
-
-# =============================================================================
-# Time Stretching and Pitch Shifting
-# =============================================================================
-
-
-def time_stretch(
- audio: AudioBuffer | GPUArray,
- rate: float,
- n_fft: int = 2048,
- hop_length: int = 512,
-) -> GPUArray:
- """Time stretch audio using phase vocoder.
-
- Changes the duration of audio without changing its pitch.
-
- Args:
- audio: Input audio (float32)
- rate: Stretch factor (>1 = faster/shorter, <1 = slower/longer)
- n_fft: FFT size (default 2048)
- hop_length: Hop size (default 512)
-
- Returns:
- Time-stretched audio [n_samples * rate]
-
- Example:
- >>> # Slow down to half speed
- >>> slow = time_stretch(buf, rate=0.5)
- >>> # Speed up to double speed
- >>> fast = time_stretch(buf, rate=2.0)
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- data = audio.data
- else:
- data = audio
-
- result = native.audio_time_stretch(data._get_native(), rate, n_fft, hop_length)
- return GPUArray._wrap_native(result)
-
-
-def pitch_shift(
- audio: AudioBuffer | GPUArray,
- sample_rate: int,
- n_steps: float,
- n_fft: int = 2048,
- hop_length: int = 512,
-) -> GPUArray:
- """Pitch shift audio using phase vocoder and resampling.
-
- Changes the pitch of audio without changing its duration.
-
- Args:
- audio: Input audio (float32)
- sample_rate: Sample rate in Hz
- n_steps: Number of semitones to shift (positive = up, negative = down)
- n_fft: FFT size (default 2048)
- hop_length: Hop size (default 512)
-
- Returns:
- Pitch-shifted audio [n_samples]
-
- Example:
- >>> # Shift up one octave
- >>> higher = pitch_shift(buf, sample_rate=16000, n_steps=12)
- >>> # Shift down a perfect fifth
- >>> lower = pitch_shift(buf, sample_rate=16000, n_steps=-7)
- """
- native = _get_native()
-
- if isinstance(audio, AudioBuffer):
- data = audio.data
- else:
- data = audio
-
- result = native.audio_pitch_shift(data._get_native(), sample_rate, n_steps, n_fft, hop_length)
- return GPUArray._wrap_native(result)
-
-
-__all__ = [
- # Classes
- "AudioBuffer",
- "AudioRingBuffer",
- "AudioStream",
- "SpeechSegment",
- "VAD",
- # Basic functions
- "from_pcm",
- # Preprocessing functions
- "preemphasis",
- "deemphasis",
- "remove_dc",
- "highpass_filter",
- "noise_gate",
- "spectral_gate",
- "compute_short_term_energy",
- # Spectral processing
- "stft",
- "power_spectrum",
- "magnitude_spectrum",
- "create_mel_filterbank",
- "apply_mel_filterbank",
- "log_mel",
- "to_decibels",
- "mfcc",
- "delta",
- # High-level functions
- "mel_spectrogram",
- "log_mel_spectrogram",
- # Inverse STFT and phase reconstruction
- "istft",
- "griffin_lim",
- # Pitch detection
- "autocorrelation",
- "detect_pitch_yin",
- "detect_pitch_yin_frames",
- # Spectral features
- "spectral_centroid",
- "spectral_bandwidth",
- "spectral_rolloff",
- "spectral_flatness",
- "spectral_contrast",
- "zero_crossing_rate",
- # CQT and Chromagram
- "cqt",
- "cqt_magnitude",
- "chroma_stft",
- "chroma_cqt",
- # HPSS
- "hpss",
- "harmonic",
- "percussive",
- # Time stretching and pitch shifting
- "time_stretch",
- "pitch_shift",
-]
diff --git a/src/pygpukit/ops/audio/__init__.py b/src/pygpukit/ops/audio/__init__.py
new file mode 100644
index 0000000..0ce11e8
--- /dev/null
+++ b/src/pygpukit/ops/audio/__init__.py
@@ -0,0 +1,167 @@
+"""GPU Audio Processing Operations.
+
+This module provides GPU-accelerated audio processing for ASR/Whisper preprocessing:
+- PCM to float conversion
+- Stereo to mono conversion
+- Peak/RMS normalization
+- Resampling (48kHz -> 16kHz)
+
+Example:
+ >>> import numpy as np
+ >>> import pygpukit as gk
+ >>> from pygpukit.ops import audio
+ >>>
+ >>> # Load PCM samples (int16)
+ >>> pcm = np.array([0, 16384, -16384, 32767], dtype=np.int16)
+ >>> buf = audio.from_pcm(pcm, sample_rate=48000)
+ >>>
+ >>> # Process audio
+ >>> buf = buf.to_mono().resample(16000).normalize()
+ >>> result = buf.data.to_numpy()
+
+Corresponds to native/ops/audio/.
+"""
+
+from __future__ import annotations
+
+# Buffer classes
+from .buffer import (
+ AudioBuffer,
+ AudioRingBuffer,
+ AudioStream,
+ from_pcm,
+)
+
+# CQT and Chromagram
+from .cqt import (
+ chroma_cqt,
+ chroma_stft,
+ cqt,
+ cqt_magnitude,
+)
+
+# Audio effects
+from .effects import (
+ pitch_shift,
+ time_stretch,
+)
+
+# Spectral features
+from .features import (
+ spectral_bandwidth,
+ spectral_centroid,
+ spectral_contrast,
+ spectral_flatness,
+ spectral_rolloff,
+ zero_crossing_rate,
+)
+
+# HPSS
+from .hpss import (
+ harmonic,
+ hpss,
+ percussive,
+)
+
+# Phase reconstruction
+from .phase import (
+ griffin_lim,
+ istft,
+)
+
+# Pitch detection
+from .pitch import (
+ autocorrelation,
+ detect_pitch_yin,
+ detect_pitch_yin_frames,
+)
+
+# Preprocessing functions
+from .preprocessing import (
+ compute_short_term_energy,
+ deemphasis,
+ highpass_filter,
+ noise_gate,
+ preemphasis,
+ remove_dc,
+ spectral_gate,
+)
+
+# Spectral processing
+from .spectral import (
+ apply_mel_filterbank,
+ create_mel_filterbank,
+ delta,
+ log_mel,
+ log_mel_spectrogram,
+ magnitude_spectrum,
+ mel_spectrogram,
+ mfcc,
+ power_spectrum,
+ stft,
+ to_decibels,
+)
+
+# VAD
+from .vad import (
+ VAD,
+ SpeechSegment,
+)
+
+__all__ = [
+ # Classes
+ "AudioBuffer",
+ "AudioRingBuffer",
+ "AudioStream",
+ "SpeechSegment",
+ "VAD",
+ # Basic functions
+ "from_pcm",
+ # Preprocessing functions
+ "preemphasis",
+ "deemphasis",
+ "remove_dc",
+ "highpass_filter",
+ "noise_gate",
+ "spectral_gate",
+ "compute_short_term_energy",
+ # Spectral processing
+ "stft",
+ "power_spectrum",
+ "magnitude_spectrum",
+ "create_mel_filterbank",
+ "apply_mel_filterbank",
+ "log_mel",
+ "to_decibels",
+ "mfcc",
+ "delta",
+ # High-level functions
+ "mel_spectrogram",
+ "log_mel_spectrogram",
+ # Inverse STFT and phase reconstruction
+ "istft",
+ "griffin_lim",
+ # Pitch detection
+ "autocorrelation",
+ "detect_pitch_yin",
+ "detect_pitch_yin_frames",
+ # Spectral features
+ "spectral_centroid",
+ "spectral_bandwidth",
+ "spectral_rolloff",
+ "spectral_flatness",
+ "spectral_contrast",
+ "zero_crossing_rate",
+ # CQT and Chromagram
+ "cqt",
+ "cqt_magnitude",
+ "chroma_stft",
+ "chroma_cqt",
+ # HPSS
+ "hpss",
+ "harmonic",
+ "percussive",
+ # Time stretching and pitch shifting
+ "time_stretch",
+ "pitch_shift",
+]
diff --git a/src/pygpukit/ops/audio/buffer.py b/src/pygpukit/ops/audio/buffer.py
new file mode 100644
index 0000000..d4a7bea
--- /dev/null
+++ b/src/pygpukit/ops/audio/buffer.py
@@ -0,0 +1,426 @@
+"""Audio buffer classes for GPU audio processing.
+
+This module provides:
+- AudioBuffer: GPU audio buffer with metadata
+- AudioRingBuffer: GPU-side ring buffer for streaming
+- AudioStream: High-level streaming audio processor
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import numpy as np
+
+from pygpukit.core import GPUArray
+from pygpukit.core import from_numpy as core_from_numpy
+from pygpukit.core.dtypes import float32, int16
+
+
+def _get_native():
+ """Get the native module."""
+ try:
+ from pygpukit._native_loader import get_native_module
+
+ return get_native_module()
+ except ImportError:
+ from pygpukit import _pygpukit_native
+
+ return _pygpukit_native
+
+
+@dataclass
+class AudioBuffer:
+ """GPU audio buffer with metadata.
+
+ Attributes:
+ data: GPUArray containing audio samples (float32)
+ sample_rate: Sample rate in Hz
+ channels: Number of channels (1=mono, 2=stereo)
+ """
+
+ data: GPUArray
+ sample_rate: int
+ channels: int
+
+ def to_mono(self) -> AudioBuffer:
+ """Convert stereo audio to mono.
+
+ Returns:
+ New AudioBuffer with mono audio (channels=1)
+
+ Raises:
+ ValueError: If already mono
+ """
+ if self.channels == 1:
+ return self
+
+ if self.channels != 2:
+ raise ValueError(f"to_mono only supports stereo (2 channels), got {self.channels}")
+
+ native = _get_native()
+ mono_data = native.audio_stereo_to_mono(self.data._get_native())
+
+ return AudioBuffer(
+ data=GPUArray._wrap_native(mono_data),
+ sample_rate=self.sample_rate,
+ channels=1,
+ )
+
+ def resample(self, target_rate: int) -> AudioBuffer:
+ """Resample audio to target sample rate.
+
+ Currently supports:
+ - 48000 -> 16000 (3:1 decimation for Whisper)
+
+ Args:
+ target_rate: Target sample rate in Hz
+
+ Returns:
+ New AudioBuffer with resampled audio
+
+ Raises:
+ ValueError: If sample rate conversion is not supported
+ """
+ if self.sample_rate == target_rate:
+ return self
+
+ native = _get_native()
+ resampled = native.audio_resample(self.data._get_native(), self.sample_rate, target_rate)
+
+ return AudioBuffer(
+ data=GPUArray._wrap_native(resampled),
+ sample_rate=target_rate,
+ channels=self.channels,
+ )
+
+ def normalize(self, mode: str = "peak", target_db: float = -20.0) -> AudioBuffer:
+ """Normalize audio level.
+
+ Args:
+ mode: Normalization mode ("peak" or "rms")
+ target_db: Target level in dB (only used for RMS mode)
+
+ Returns:
+ Self (in-place normalization)
+
+ Raises:
+ ValueError: If mode is not "peak" or "rms"
+ """
+ native = _get_native()
+
+ if mode == "peak":
+ native.audio_normalize_peak(self.data._get_native())
+ elif mode == "rms":
+ native.audio_normalize_rms(self.data._get_native(), target_db)
+ else:
+ raise ValueError(f"Unknown normalization mode: {mode}. Use 'peak' or 'rms'.")
+
+ return self
+
+ def to_numpy(self) -> np.ndarray:
+ """Convert audio data to NumPy array.
+
+ Returns:
+ NumPy array of float32 samples
+ """
+ return self.data.to_numpy()
+
+ def __repr__(self) -> str:
+ return (
+ f"AudioBuffer(samples={self.data.shape[0]}, "
+ f"sample_rate={self.sample_rate}, channels={self.channels})"
+ )
+
+
+def from_pcm(
+ samples: np.ndarray | GPUArray,
+ sample_rate: int,
+ channels: int = 1,
+) -> AudioBuffer:
+ """Create AudioBuffer from PCM samples.
+
+ Args:
+ samples: PCM samples as int16 or float32 array
+ sample_rate: Sample rate in Hz (e.g., 48000, 16000)
+ channels: Number of channels (1=mono, 2=stereo)
+
+ Returns:
+ AudioBuffer with audio data on GPU
+
+ Example:
+ >>> pcm = np.array([0, 16384, -16384], dtype=np.int16)
+ >>> buf = from_pcm(pcm, sample_rate=48000)
+ """
+ native = _get_native()
+
+ # Convert to GPUArray if needed
+ if isinstance(samples, np.ndarray):
+ gpu_samples = core_from_numpy(samples)
+ else:
+ gpu_samples = samples
+
+ # Convert int16 PCM to float32
+ if gpu_samples.dtype == int16:
+ float_data = native.audio_pcm_to_float32(gpu_samples._get_native())
+ gpu_data = GPUArray._wrap_native(float_data)
+ elif gpu_samples.dtype == float32:
+ # Already float32, just use as-is
+ gpu_data = gpu_samples
+ else:
+ raise ValueError(f"Unsupported dtype: {gpu_samples.dtype}. Use int16 or float32.")
+
+ return AudioBuffer(
+ data=gpu_data,
+ sample_rate=sample_rate,
+ channels=channels,
+ )
+
+
+class AudioRingBuffer:
+ """GPU-side ring buffer for streaming audio.
+
+ Provides efficient circular buffer operations for real-time audio processing.
+
+ Args:
+ capacity: Buffer capacity in samples
+ sample_rate: Sample rate in Hz (for metadata)
+
+ Example:
+ >>> ring = AudioRingBuffer(capacity=48000, sample_rate=16000) # 3 sec buffer
+ >>> ring.write(chunk1)
+ >>> ring.write(chunk2)
+ >>> window = ring.read(16000) # Read 1 second
+ """
+
+ def __init__(self, capacity: int, sample_rate: int = 16000):
+ from pygpukit.core import zeros
+
+ self._buffer = zeros((capacity,), dtype="float32")
+ self._capacity = capacity
+ self._sample_rate = sample_rate
+ self._write_pos = 0
+ self._samples_written = 0
+
+ @property
+ def capacity(self) -> int:
+ """Buffer capacity in samples."""
+ return self._capacity
+
+ @property
+ def sample_rate(self) -> int:
+ """Sample rate in Hz."""
+ return self._sample_rate
+
+ @property
+ def samples_available(self) -> int:
+ """Number of samples available for reading."""
+ return min(self._samples_written, self._capacity)
+
+ @property
+ def duration_available(self) -> float:
+ """Duration of available audio in seconds."""
+ return self.samples_available / self._sample_rate
+
+ def write(self, samples: np.ndarray | GPUArray) -> int:
+ """Write samples to the ring buffer.
+
+ Args:
+ samples: Audio samples to write (float32)
+
+ Returns:
+ Number of samples written
+ """
+ native = _get_native()
+
+ # Convert to GPUArray if needed
+ if isinstance(samples, np.ndarray):
+ gpu_samples = core_from_numpy(samples.astype(np.float32))
+ else:
+ gpu_samples = samples
+
+ num_samples = gpu_samples.shape[0]
+
+ # Write to ring buffer
+ native.audio_ring_buffer_write(
+ gpu_samples._get_native(),
+ self._buffer._get_native(),
+ self._write_pos,
+ )
+
+ # Update write position
+ self._write_pos = (self._write_pos + num_samples) % self._capacity
+ self._samples_written += num_samples
+
+ return num_samples
+
+ def read(self, num_samples: int, offset: int = 0) -> GPUArray:
+ """Read samples from the ring buffer.
+
+ Args:
+ num_samples: Number of samples to read
+ offset: Offset from current read position (0 = most recent)
+
+ Returns:
+ GPUArray of audio samples
+ """
+ native = _get_native()
+
+ # Calculate read position (read from oldest available)
+ if self._samples_written <= self._capacity:
+ read_pos = offset
+ else:
+ read_pos = (self._write_pos + offset) % self._capacity
+
+ result = native.audio_ring_buffer_read(
+ self._buffer._get_native(),
+ read_pos,
+ num_samples,
+ )
+
+ return GPUArray._wrap_native(result)
+
+ def clear(self) -> None:
+ """Clear the buffer."""
+ from pygpukit.core import zeros
+
+ self._buffer = zeros((self._capacity,), dtype="float32")
+ self._write_pos = 0
+ self._samples_written = 0
+
+ def __repr__(self) -> str:
+ return (
+ f"AudioRingBuffer(capacity={self._capacity}, "
+ f"sample_rate={self._sample_rate}, "
+ f"available={self.samples_available})"
+ )
+
+
+class AudioStream:
+ """High-level streaming audio processor.
+
+ Provides chunked processing with windowing for smooth transitions.
+ Suitable for real-time ASR preprocessing.
+
+ Args:
+ chunk_size: Processing chunk size in samples (default: 480 = 30ms @ 16kHz)
+ hop_size: Hop size between chunks (default: chunk_size // 2 for 50% overlap)
+ sample_rate: Sample rate in Hz
+ buffer_duration: Ring buffer duration in seconds
+
+ Example:
+ >>> stream = AudioStream(chunk_size=480, sample_rate=16000)
+ >>> for pcm_chunk in audio_source:
+ ... stream.push(pcm_chunk)
+ ... if stream.has_chunk():
+ ... chunk = stream.pop_chunk()
+ ... # Process chunk for ASR
+ """
+
+ def __init__(
+ self,
+ chunk_size: int = 480,
+ hop_size: int | None = None,
+ sample_rate: int = 16000,
+ buffer_duration: float = 30.0,
+ ):
+ self._chunk_size = chunk_size
+ self._hop_size = hop_size if hop_size is not None else chunk_size // 2
+ self._sample_rate = sample_rate
+
+ # Ring buffer for incoming audio
+ buffer_samples = int(buffer_duration * sample_rate)
+ self._ring_buffer = AudioRingBuffer(buffer_samples, sample_rate)
+
+ # Track chunk position
+ self._chunks_processed = 0
+
+ @property
+ def chunk_size(self) -> int:
+ """Chunk size in samples."""
+ return self._chunk_size
+
+ @property
+ def hop_size(self) -> int:
+ """Hop size in samples."""
+ return self._hop_size
+
+ @property
+ def sample_rate(self) -> int:
+ """Sample rate in Hz."""
+ return self._sample_rate
+
+ def push(self, samples: np.ndarray | GPUArray) -> int:
+ """Push audio samples to the stream.
+
+ Args:
+ samples: Audio samples (float32)
+
+ Returns:
+ Number of samples pushed
+ """
+ return self._ring_buffer.write(samples)
+
+ def has_chunk(self) -> bool:
+ """Check if a full chunk is available."""
+ required = self._chunks_processed * self._hop_size + self._chunk_size
+ return self._ring_buffer._samples_written >= required
+
+ def pop_chunk(self, apply_window: bool = True) -> GPUArray:
+ """Pop the next chunk from the stream.
+
+ Args:
+ apply_window: Whether to apply Hann window (default True)
+
+ Returns:
+ GPUArray containing the chunk
+
+ Raises:
+ RuntimeError: If no chunk is available
+ """
+ if not self.has_chunk():
+ raise RuntimeError("No chunk available. Call has_chunk() first.")
+
+ native = _get_native()
+
+ # Calculate read offset
+ read_offset = self._chunks_processed * self._hop_size
+
+ # Read chunk from ring buffer
+ chunk = self._ring_buffer.read(self._chunk_size, read_offset)
+
+ # Apply window if requested
+ if apply_window:
+ native.audio_apply_hann_window(chunk._get_native())
+
+ self._chunks_processed += 1
+ return chunk
+
+ def reset(self) -> None:
+ """Reset the stream state."""
+ self._ring_buffer.clear()
+ self._chunks_processed = 0
+
+ @property
+ def chunks_available(self) -> int:
+ """Number of complete chunks available."""
+ if self._ring_buffer._samples_written < self._chunk_size:
+ return 0
+ available = self._ring_buffer._samples_written - self._chunk_size
+ return available // self._hop_size + 1 - self._chunks_processed
+
+ def __repr__(self) -> str:
+ return (
+ f"AudioStream(chunk_size={self._chunk_size}, "
+ f"hop_size={self._hop_size}, "
+ f"sample_rate={self._sample_rate}, "
+ f"chunks_available={self.chunks_available})"
+ )
+
+
+__all__ = [
+ "AudioBuffer",
+ "AudioRingBuffer",
+ "AudioStream",
+ "from_pcm",
+]
diff --git a/src/pygpukit/ops/audio/cqt.py b/src/pygpukit/ops/audio/cqt.py
new file mode 100644
index 0000000..6579e8e
--- /dev/null
+++ b/src/pygpukit/ops/audio/cqt.py
@@ -0,0 +1,155 @@
+"""Constant-Q Transform and Chromagram for GPU audio processing.
+
+This module provides:
+- CQT (Constant-Q Transform)
+- Chromagram from STFT and CQT
+"""
+
+from __future__ import annotations
+
+from pygpukit.core import GPUArray
+
+from .buffer import AudioBuffer
+from .spectral import magnitude_spectrum
+
+
+def _get_native():
+ """Get the native module."""
+ try:
+ from pygpukit._native_loader import get_native_module
+
+ return get_native_module()
+ except ImportError:
+ from pygpukit import _pygpukit_native
+
+ return _pygpukit_native
+
+
+def cqt(
+ audio: AudioBuffer | GPUArray,
+ sample_rate: int = 16000,
+ hop_length: int = 160,
+ f_min: float = 32.7,
+ n_bins: int = 84,
+ bins_per_octave: int = 12,
+) -> GPUArray:
+ """Compute Constant-Q Transform (CQT).
+
+ CQT provides logarithmically-spaced frequency resolution, useful for
+ music analysis where notes are logarithmically distributed.
+
+ This implementation uses STFT-based approximation for efficiency.
+
+ Args:
+ audio: Input audio (float32)
+ sample_rate: Sample rate in Hz
+ hop_length: Hop size (default 160)
+ f_min: Minimum frequency (default 32.7 Hz = C1)
+ n_bins: Number of frequency bins (default 84 = 7 octaves)
+ bins_per_octave: Bins per octave (default 12)
+
+ Returns:
+ Complex CQT [n_frames, n_bins, 2] (real, imag)
+
+ Example:
+ >>> cqt_out = cqt(buf, sample_rate=16000, n_bins=84)
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ data = audio.data
+ else:
+ data = audio
+
+ result = native.audio_cqt(
+ data._get_native(), sample_rate, hop_length, f_min, n_bins, bins_per_octave
+ )
+ return GPUArray._wrap_native(result)
+
+
+def cqt_magnitude(
+ audio: AudioBuffer | GPUArray,
+ sample_rate: int = 16000,
+ hop_length: int = 160,
+ f_min: float = 32.7,
+ n_bins: int = 84,
+ bins_per_octave: int = 12,
+) -> GPUArray:
+ """Compute CQT magnitude spectrogram.
+
+ Convenience function that computes CQT and returns magnitude.
+
+ Args:
+ audio: Input audio (float32)
+ sample_rate: Sample rate in Hz
+ hop_length: Hop size (default 160)
+ f_min: Minimum frequency (default 32.7 Hz = C1)
+ n_bins: Number of frequency bins (default 84)
+ bins_per_octave: Bins per octave (default 12)
+
+ Returns:
+ CQT magnitude [n_frames, n_bins]
+
+ Example:
+ >>> cqt_mag = cqt_magnitude(buf, sample_rate=16000)
+ """
+ cqt_out = cqt(audio, sample_rate, hop_length, f_min, n_bins, bins_per_octave)
+ return magnitude_spectrum(cqt_out)
+
+
+def chroma_stft(
+ spectrum: GPUArray,
+ sample_rate: int = 16000,
+ n_chroma: int = 12,
+ tuning: float = 0.0,
+) -> GPUArray:
+ """Compute chromagram from STFT magnitude spectrum.
+
+ Maps the spectrum to 12 pitch classes (C, C#, D, ..., B).
+
+ Args:
+ spectrum: Magnitude spectrum [n_frames, n_freq]
+ sample_rate: Sample rate in Hz
+ n_chroma: Number of chroma bins (default 12)
+ tuning: Tuning deviation in fractions of a chroma bin (default 0)
+
+ Returns:
+ Chromagram [n_frames, n_chroma]
+
+ Example:
+ >>> mag = magnitude_spectrum(stft_out)
+ >>> chroma = chroma_stft(mag, sample_rate=16000)
+ """
+ native = _get_native()
+ result = native.audio_chroma_stft(spectrum._get_native(), sample_rate, n_chroma, tuning)
+ return GPUArray._wrap_native(result)
+
+
+def chroma_cqt(
+ cqt_magnitude_input: GPUArray,
+ bins_per_octave: int = 12,
+) -> GPUArray:
+ """Compute chromagram from CQT magnitude.
+
+ Args:
+ cqt_magnitude_input: CQT magnitude [n_frames, n_bins]
+ bins_per_octave: Bins per octave in CQT (default 12)
+
+ Returns:
+ Chromagram [n_frames, bins_per_octave]
+
+ Example:
+ >>> cqt_mag = cqt_magnitude(buf, bins_per_octave=12)
+ >>> chroma = chroma_cqt(cqt_mag, bins_per_octave=12)
+ """
+ native = _get_native()
+ result = native.audio_chroma_cqt(cqt_magnitude_input._get_native(), bins_per_octave)
+ return GPUArray._wrap_native(result)
+
+
+__all__ = [
+ "cqt",
+ "cqt_magnitude",
+ "chroma_stft",
+ "chroma_cqt",
+]
diff --git a/src/pygpukit/ops/audio/effects.py b/src/pygpukit/ops/audio/effects.py
new file mode 100644
index 0000000..c16f766
--- /dev/null
+++ b/src/pygpukit/ops/audio/effects.py
@@ -0,0 +1,104 @@
+"""Audio effects for GPU audio processing.
+
+This module provides:
+- Time stretching using phase vocoder
+- Pitch shifting
+"""
+
+from __future__ import annotations
+
+from pygpukit.core import GPUArray
+
+from .buffer import AudioBuffer
+
+
+def _get_native():
+ """Get the native module."""
+ try:
+ from pygpukit._native_loader import get_native_module
+
+ return get_native_module()
+ except ImportError:
+ from pygpukit import _pygpukit_native
+
+ return _pygpukit_native
+
+
+def time_stretch(
+ audio: AudioBuffer | GPUArray,
+ rate: float,
+ n_fft: int = 2048,
+ hop_length: int = 512,
+) -> GPUArray:
+ """Time stretch audio using phase vocoder.
+
+ Changes the duration of audio without changing its pitch.
+
+ Args:
+ audio: Input audio (float32)
+ rate: Stretch factor (>1 = faster/shorter, <1 = slower/longer)
+ n_fft: FFT size (default 2048)
+ hop_length: Hop size (default 512)
+
+ Returns:
+ Time-stretched audio [n_samples * rate]
+
+ Example:
+ >>> # Slow down to half speed
+ >>> slow = time_stretch(buf, rate=0.5)
+ >>> # Speed up to double speed
+ >>> fast = time_stretch(buf, rate=2.0)
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ data = audio.data
+ else:
+ data = audio
+
+ result = native.audio_time_stretch(data._get_native(), rate, n_fft, hop_length)
+ return GPUArray._wrap_native(result)
+
+
+def pitch_shift(
+ audio: AudioBuffer | GPUArray,
+ sample_rate: int,
+ n_steps: float,
+ n_fft: int = 2048,
+ hop_length: int = 512,
+) -> GPUArray:
+ """Pitch shift audio using phase vocoder and resampling.
+
+ Changes the pitch of audio without changing its duration.
+
+ Args:
+ audio: Input audio (float32)
+ sample_rate: Sample rate in Hz
+ n_steps: Number of semitones to shift (positive = up, negative = down)
+ n_fft: FFT size (default 2048)
+ hop_length: Hop size (default 512)
+
+ Returns:
+ Pitch-shifted audio [n_samples]
+
+ Example:
+ >>> # Shift up one octave
+ >>> higher = pitch_shift(buf, sample_rate=16000, n_steps=12)
+ >>> # Shift down a perfect fifth
+ >>> lower = pitch_shift(buf, sample_rate=16000, n_steps=-7)
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ data = audio.data
+ else:
+ data = audio
+
+ result = native.audio_pitch_shift(data._get_native(), sample_rate, n_steps, n_fft, hop_length)
+ return GPUArray._wrap_native(result)
+
+
+__all__ = [
+ "time_stretch",
+ "pitch_shift",
+]
diff --git a/src/pygpukit/ops/audio/features.py b/src/pygpukit/ops/audio/features.py
new file mode 100644
index 0000000..f76a4c9
--- /dev/null
+++ b/src/pygpukit/ops/audio/features.py
@@ -0,0 +1,199 @@
+"""Spectral feature extraction for GPU audio processing.
+
+This module provides:
+- Spectral centroid, bandwidth, rolloff, flatness, contrast
+- Zero-crossing rate
+"""
+
+from __future__ import annotations
+
+from pygpukit.core import GPUArray
+
+from .buffer import AudioBuffer
+
+
+def _get_native():
+ """Get the native module."""
+ try:
+ from pygpukit._native_loader import get_native_module
+
+ return get_native_module()
+ except ImportError:
+ from pygpukit import _pygpukit_native
+
+ return _pygpukit_native
+
+
+def spectral_centroid(
+ spectrum: GPUArray,
+ sample_rate: int = 16000,
+) -> GPUArray:
+ """Compute spectral centroid for each frame.
+
+ The spectral centroid indicates the "center of mass" of the spectrum.
+
+ Args:
+ spectrum: Magnitude or power spectrum [n_frames, n_freq]
+ sample_rate: Sample rate in Hz
+
+ Returns:
+ Spectral centroid in Hz for each frame [n_frames]
+
+ Example:
+ >>> mag = magnitude_spectrum(stft_out)
+ >>> centroid = spectral_centroid(mag, sample_rate=16000)
+ """
+ native = _get_native()
+ result = native.audio_spectral_centroid(spectrum._get_native(), sample_rate)
+ return GPUArray._wrap_native(result)
+
+
+def spectral_bandwidth(
+ spectrum: GPUArray,
+ centroids: GPUArray,
+ sample_rate: int = 16000,
+ p: int = 2,
+) -> GPUArray:
+ """Compute spectral bandwidth for each frame.
+
+ Spectral bandwidth is the weighted standard deviation of frequencies
+ around the spectral centroid.
+
+ Args:
+ spectrum: Magnitude or power spectrum [n_frames, n_freq]
+ centroids: Pre-computed spectral centroids [n_frames]
+ sample_rate: Sample rate in Hz
+ p: Order for bandwidth computation (default 2)
+
+ Returns:
+ Spectral bandwidth in Hz for each frame [n_frames]
+
+ Example:
+ >>> mag = magnitude_spectrum(stft_out)
+ >>> centroid = spectral_centroid(mag, sample_rate=16000)
+ >>> bandwidth = spectral_bandwidth(mag, centroid, sample_rate=16000)
+ """
+ native = _get_native()
+ result = native.audio_spectral_bandwidth(
+ spectrum._get_native(), centroids._get_native(), sample_rate, p
+ )
+ return GPUArray._wrap_native(result)
+
+
+def spectral_rolloff(
+ spectrum: GPUArray,
+ sample_rate: int = 16000,
+ roll_percent: float = 0.85,
+) -> GPUArray:
+ """Compute spectral rolloff for each frame.
+
+ The rolloff frequency is the frequency below which roll_percent of
+ the total spectral energy is contained.
+
+ Args:
+ spectrum: Magnitude or power spectrum [n_frames, n_freq]
+ sample_rate: Sample rate in Hz
+ roll_percent: Percentage of energy (default 0.85)
+
+ Returns:
+ Rolloff frequency in Hz for each frame [n_frames]
+
+ Example:
+ >>> mag = magnitude_spectrum(stft_out)
+ >>> rolloff = spectral_rolloff(mag, sample_rate=16000, roll_percent=0.85)
+ """
+ native = _get_native()
+ result = native.audio_spectral_rolloff(spectrum._get_native(), sample_rate, roll_percent)
+ return GPUArray._wrap_native(result)
+
+
+def spectral_flatness(spectrum: GPUArray) -> GPUArray:
+ """Compute spectral flatness for each frame.
+
+ Spectral flatness measures how tone-like vs noise-like a sound is.
+ Values close to 1 indicate noise, values close to 0 indicate tonal content.
+
+ Computed as: geometric_mean / arithmetic_mean
+
+ Args:
+ spectrum: Magnitude or power spectrum [n_frames, n_freq]
+
+ Returns:
+ Spectral flatness for each frame [n_frames] (0 to 1)
+
+ Example:
+ >>> mag = magnitude_spectrum(stft_out)
+ >>> flatness = spectral_flatness(mag)
+ """
+ native = _get_native()
+ result = native.audio_spectral_flatness(spectrum._get_native())
+ return GPUArray._wrap_native(result)
+
+
+def spectral_contrast(
+ spectrum: GPUArray,
+ n_bands: int = 6,
+ alpha: float = 0.2,
+) -> GPUArray:
+ """Compute spectral contrast for each frame.
+
+ Spectral contrast measures the difference between peaks and valleys
+ in the spectrum, divided into frequency bands.
+
+ Args:
+ spectrum: Magnitude or power spectrum [n_frames, n_freq]
+ n_bands: Number of frequency bands (default 6)
+ alpha: Percentile for peak/valley estimation (default 0.2)
+
+ Returns:
+ Spectral contrast [n_frames, n_bands]
+
+ Example:
+ >>> mag = magnitude_spectrum(stft_out)
+ >>> contrast = spectral_contrast(mag, n_bands=6)
+ """
+ native = _get_native()
+ result = native.audio_spectral_contrast(spectrum._get_native(), n_bands, alpha)
+ return GPUArray._wrap_native(result)
+
+
+def zero_crossing_rate(
+ audio: AudioBuffer | GPUArray,
+ frame_size: int = 512,
+ hop_size: int = 256,
+) -> GPUArray:
+ """Compute zero-crossing rate for each frame.
+
+ ZCR counts the number of times the signal crosses zero per frame,
+ normalized by frame size.
+
+ Args:
+ audio: Input audio (float32)
+ frame_size: Frame size in samples (default 512)
+ hop_size: Hop size in samples (default 256)
+
+ Returns:
+ Zero-crossing rate for each frame [n_frames]
+
+ Example:
+ >>> zcr = zero_crossing_rate(buf, frame_size=512, hop_size=256)
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ data = audio.data
+ else:
+ data = audio
+
+ result = native.audio_zero_crossing_rate(data._get_native(), frame_size, hop_size)
+ return GPUArray._wrap_native(result)
+
+
+__all__ = [
+ "spectral_centroid",
+ "spectral_bandwidth",
+ "spectral_rolloff",
+ "spectral_flatness",
+ "spectral_contrast",
+ "zero_crossing_rate",
+]
diff --git a/src/pygpukit/ops/audio/hpss.py b/src/pygpukit/ops/audio/hpss.py
new file mode 100644
index 0000000..dcd499f
--- /dev/null
+++ b/src/pygpukit/ops/audio/hpss.py
@@ -0,0 +1,108 @@
+"""Harmonic-Percussive Source Separation (HPSS) for GPU audio processing.
+
+This module provides:
+- HPSS (Harmonic-Percussive Source Separation)
+- Harmonic and percussive component extraction
+"""
+
+from __future__ import annotations
+
+from pygpukit.core import GPUArray
+
+
+def _get_native():
+ """Get the native module."""
+ try:
+ from pygpukit._native_loader import get_native_module
+
+ return get_native_module()
+ except ImportError:
+ from pygpukit import _pygpukit_native
+
+ return _pygpukit_native
+
+
+def hpss(
+ stft_magnitude_input: GPUArray,
+ kernel_size: int = 31,
+ power: float = 2.0,
+ margin: float = 1.0,
+) -> tuple[GPUArray, GPUArray]:
+ """Harmonic-Percussive Source Separation using median filtering.
+
+ Separates audio into harmonic (tonal) and percussive (transient) components
+ using median filtering in time and frequency directions.
+
+ Args:
+ stft_magnitude_input: STFT magnitude [n_frames, n_freq]
+ kernel_size: Median filter kernel size (default 31)
+ power: Power for spectrogram (default 2.0)
+ margin: Margin for soft masking (default 1.0)
+
+ Returns:
+ Tuple of (harmonic_magnitude, percussive_magnitude)
+
+ Example:
+ >>> mag = magnitude_spectrum(stft_out)
+ >>> harmonic, percussive = hpss(mag)
+ """
+ native = _get_native()
+ h, p = native.audio_hpss(stft_magnitude_input._get_native(), kernel_size, power, margin)
+ return GPUArray._wrap_native(h), GPUArray._wrap_native(p)
+
+
+def harmonic(
+ stft_magnitude_input: GPUArray,
+ kernel_size: int = 31,
+ power: float = 2.0,
+ margin: float = 1.0,
+) -> GPUArray:
+ """Extract harmonic component using HPSS.
+
+ Args:
+ stft_magnitude_input: STFT magnitude [n_frames, n_freq]
+ kernel_size: Median filter kernel size (default 31)
+ power: Power for spectrogram (default 2.0)
+ margin: Margin for soft masking (default 1.0)
+
+ Returns:
+ Harmonic magnitude [n_frames, n_freq]
+
+ Example:
+ >>> mag = magnitude_spectrum(stft_out)
+ >>> harm = harmonic(mag)
+ """
+ h, _ = hpss(stft_magnitude_input, kernel_size, power, margin)
+ return h
+
+
+def percussive(
+ stft_magnitude_input: GPUArray,
+ kernel_size: int = 31,
+ power: float = 2.0,
+ margin: float = 1.0,
+) -> GPUArray:
+ """Extract percussive component using HPSS.
+
+ Args:
+ stft_magnitude_input: STFT magnitude [n_frames, n_freq]
+ kernel_size: Median filter kernel size (default 31)
+ power: Power for spectrogram (default 2.0)
+ margin: Margin for soft masking (default 1.0)
+
+ Returns:
+ Percussive magnitude [n_frames, n_freq]
+
+ Example:
+ >>> mag = magnitude_spectrum(stft_out)
+ >>> perc = percussive(mag)
+ """
+ _, p = hpss(stft_magnitude_input, kernel_size, power, margin)
+ return p
+
+
+__all__ = [
+ "hpss",
+ "harmonic",
+ "percussive",
+]
diff --git a/src/pygpukit/ops/audio/phase.py b/src/pygpukit/ops/audio/phase.py
new file mode 100644
index 0000000..0434fed
--- /dev/null
+++ b/src/pygpukit/ops/audio/phase.py
@@ -0,0 +1,88 @@
+"""Phase reconstruction functions for GPU audio processing.
+
+This module provides:
+- ISTFT (Inverse Short-Time Fourier Transform)
+- Griffin-Lim algorithm for phase reconstruction
+"""
+
+from __future__ import annotations
+
+from pygpukit.core import GPUArray
+
+
+def _get_native():
+ """Get the native module."""
+ try:
+ from pygpukit._native_loader import get_native_module
+
+ return get_native_module()
+ except ImportError:
+ from pygpukit import _pygpukit_native
+
+ return _pygpukit_native
+
+
+def istft(
+ stft_output: GPUArray,
+ hop_length: int = 160,
+ win_length: int = -1,
+ center: bool = True,
+ length: int = -1,
+) -> GPUArray:
+ """Compute Inverse Short-Time Fourier Transform (ISTFT).
+
+ Reconstructs time-domain signal from complex STFT representation
+ using overlap-add with window sum normalization.
+
+ Args:
+ stft_output: Complex STFT [n_frames, n_freq, 2] (real, imag)
+ hop_length: Hop size (default 160)
+ win_length: Window length (default: (n_freq-1)*2)
+ center: Whether input was centered (default True)
+ length: Output length (-1 for automatic)
+
+ Returns:
+ Time-domain signal [n_samples]
+
+ Example:
+ >>> stft_out = stft(buf, n_fft=512, hop_length=160)
+ >>> reconstructed = istft(stft_out, hop_length=160)
+ """
+ native = _get_native()
+ result = native.audio_istft(stft_output._get_native(), hop_length, win_length, center, length)
+ return GPUArray._wrap_native(result)
+
+
+def griffin_lim(
+ magnitude: GPUArray,
+ n_iter: int = 32,
+ hop_length: int = 160,
+ win_length: int = -1,
+) -> GPUArray:
+ """Griffin-Lim algorithm for phase reconstruction.
+
+ Reconstructs time-domain signal from magnitude spectrogram only,
+ iteratively estimating phase using STFT/ISTFT consistency.
+
+ Args:
+ magnitude: Magnitude spectrogram [n_frames, n_freq]
+ n_iter: Number of iterations (default 32)
+ hop_length: Hop size (default 160)
+ win_length: Window length (default: (n_freq-1)*2)
+
+ Returns:
+ Reconstructed time-domain signal [n_samples]
+
+ Example:
+ >>> mag = magnitude_spectrum(stft_out)
+ >>> reconstructed = griffin_lim(mag, n_iter=32)
+ """
+ native = _get_native()
+ result = native.audio_griffin_lim(magnitude._get_native(), n_iter, hop_length, win_length)
+ return GPUArray._wrap_native(result)
+
+
+__all__ = [
+ "istft",
+ "griffin_lim",
+]
diff --git a/src/pygpukit/ops/audio/pitch.py b/src/pygpukit/ops/audio/pitch.py
new file mode 100644
index 0000000..74e651f
--- /dev/null
+++ b/src/pygpukit/ops/audio/pitch.py
@@ -0,0 +1,132 @@
+"""Pitch detection functions for GPU audio processing.
+
+This module provides:
+- Autocorrelation function
+- YIN pitch detection algorithm
+"""
+
+from __future__ import annotations
+
+from pygpukit.core import GPUArray
+
+from .buffer import AudioBuffer
+
+
+def _get_native():
+ """Get the native module."""
+ try:
+ from pygpukit._native_loader import get_native_module
+
+ return get_native_module()
+ except ImportError:
+ from pygpukit import _pygpukit_native
+
+ return _pygpukit_native
+
+
+def autocorrelation(audio: AudioBuffer | GPUArray, max_lag: int) -> GPUArray:
+ """Compute autocorrelation function.
+
+ Args:
+ audio: Input audio (float32)
+ max_lag: Maximum lag in samples
+
+ Returns:
+ Autocorrelation values [max_lag]
+
+ Example:
+ >>> acf = autocorrelation(buf, max_lag=400) # 25ms @ 16kHz
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ data = audio.data
+ else:
+ data = audio
+
+ result = native.audio_autocorrelation(data._get_native(), max_lag)
+ return GPUArray._wrap_native(result)
+
+
+def detect_pitch_yin(
+ audio: AudioBuffer | GPUArray,
+ sample_rate: int = 16000,
+ f_min: float = 50.0,
+ f_max: float = 500.0,
+ threshold: float = 0.1,
+) -> float:
+ """Detect pitch using YIN algorithm.
+
+ The YIN algorithm detects the fundamental frequency of a quasi-periodic
+ signal using cumulative mean normalized difference function.
+
+ Args:
+ audio: Input audio frame (float32)
+ sample_rate: Sample rate in Hz
+ f_min: Minimum frequency to detect (default 50 Hz)
+ f_max: Maximum frequency to detect (default 500 Hz)
+ threshold: YIN threshold (default 0.1)
+
+ Returns:
+ Detected pitch in Hz (0.0 if unvoiced)
+
+ Example:
+ >>> pitch = detect_pitch_yin(audio_frame, sample_rate=16000)
+ >>> if pitch > 0:
+ ... print(f"Pitch: {pitch:.1f} Hz")
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ data = audio.data
+ else:
+ data = audio
+
+ return native.audio_detect_pitch_yin(data._get_native(), sample_rate, f_min, f_max, threshold)
+
+
+def detect_pitch_yin_frames(
+ audio: AudioBuffer | GPUArray,
+ sample_rate: int = 16000,
+ frame_size: int = 1024,
+ hop_size: int = 256,
+ f_min: float = 50.0,
+ f_max: float = 500.0,
+ threshold: float = 0.1,
+) -> GPUArray:
+ """Detect pitch for each frame using YIN algorithm.
+
+ Args:
+ audio: Input audio (float32)
+ sample_rate: Sample rate in Hz
+ frame_size: Frame size in samples (default 1024)
+ hop_size: Hop size in samples (default 256)
+ f_min: Minimum frequency to detect (default 50 Hz)
+ f_max: Maximum frequency to detect (default 500 Hz)
+ threshold: YIN threshold (default 0.1)
+
+ Returns:
+ Pitch values for each frame [n_frames]
+
+ Example:
+ >>> pitches = detect_pitch_yin_frames(buf, sample_rate=16000)
+ >>> voiced = pitches.to_numpy() > 0
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ data = audio.data
+ else:
+ data = audio
+
+ result = native.audio_detect_pitch_yin_frames(
+ data._get_native(), sample_rate, frame_size, hop_size, f_min, f_max, threshold
+ )
+ return GPUArray._wrap_native(result)
+
+
+__all__ = [
+ "autocorrelation",
+ "detect_pitch_yin",
+ "detect_pitch_yin_frames",
+]
diff --git a/src/pygpukit/ops/audio/preprocessing.py b/src/pygpukit/ops/audio/preprocessing.py
new file mode 100644
index 0000000..02c89d4
--- /dev/null
+++ b/src/pygpukit/ops/audio/preprocessing.py
@@ -0,0 +1,249 @@
+"""Audio preprocessing functions for GPU audio processing.
+
+This module provides:
+- Pre-emphasis and de-emphasis filters
+- DC removal
+- High-pass filtering
+- Noise gate and spectral gate
+- Short-term energy computation
+"""
+
+from __future__ import annotations
+
+from pygpukit.core import GPUArray
+
+from .buffer import AudioBuffer
+
+
+def _get_native():
+ """Get the native module."""
+ try:
+ from pygpukit._native_loader import get_native_module
+
+ return get_native_module()
+ except ImportError:
+ from pygpukit import _pygpukit_native
+
+ return _pygpukit_native
+
+
+def preemphasis(audio: AudioBuffer | GPUArray, alpha: float = 0.97) -> AudioBuffer | GPUArray:
+ """Apply pre-emphasis filter to emphasize high-frequency components.
+
+ Pre-emphasis is commonly used in speech processing to boost high frequencies
+ that are typically attenuated during recording.
+
+ Formula: y[n] = x[n] - alpha * x[n-1]
+
+ Args:
+ audio: AudioBuffer or GPUArray of float32 samples
+ alpha: Pre-emphasis coefficient (default 0.97)
+
+ Returns:
+ Same type as input (modified in-place)
+
+ Example:
+ >>> buf = from_pcm(pcm_data, sample_rate=16000)
+ >>> preemphasis(buf, alpha=0.97)
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ native.audio_preemphasis(audio.data._get_native(), alpha)
+ return audio
+ else:
+ native.audio_preemphasis(audio._get_native(), alpha)
+ return audio
+
+
+def deemphasis(audio: AudioBuffer | GPUArray, alpha: float = 0.97) -> AudioBuffer | GPUArray:
+ """Apply de-emphasis filter (inverse of pre-emphasis).
+
+ Used to restore the original spectral balance after pre-emphasis.
+
+ Formula: y[n] = x[n] + alpha * y[n-1]
+
+ Args:
+ audio: AudioBuffer or GPUArray of float32 samples
+ alpha: De-emphasis coefficient (default 0.97)
+
+ Returns:
+ Same type as input (modified in-place)
+
+ Example:
+ >>> buf = preemphasis(buf)
+ >>> # ... processing ...
+ >>> deemphasis(buf) # Restore original balance
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ native.audio_deemphasis(audio.data._get_native(), alpha)
+ return audio
+ else:
+ native.audio_deemphasis(audio._get_native(), alpha)
+ return audio
+
+
+def remove_dc(audio: AudioBuffer | GPUArray) -> AudioBuffer | GPUArray:
+ """Remove DC offset from audio signal.
+
+ Subtracts the mean value from all samples, centering the signal at zero.
+ This is a simple but effective way to remove DC bias.
+
+ Args:
+ audio: AudioBuffer or GPUArray of float32 samples
+
+ Returns:
+ Same type as input (modified in-place)
+
+ Example:
+ >>> buf = from_pcm(pcm_data, sample_rate=16000)
+ >>> remove_dc(buf)
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ native.audio_remove_dc(audio.data._get_native())
+ return audio
+ else:
+ native.audio_remove_dc(audio._get_native())
+ return audio
+
+
+def highpass_filter(
+ audio: AudioBuffer | GPUArray,
+ cutoff_hz: float = 20.0,
+ sample_rate: int | None = None,
+) -> AudioBuffer | GPUArray:
+ """Apply high-pass filter for DC removal.
+
+ Uses a single-pole IIR high-pass filter, which is more effective than
+ simple mean subtraction for removing low-frequency noise.
+
+ Args:
+ audio: AudioBuffer or GPUArray of float32 samples
+ cutoff_hz: Cutoff frequency in Hz (default 20.0)
+ sample_rate: Sample rate in Hz (auto-detected from AudioBuffer)
+
+ Returns:
+ Same type as input (modified in-place)
+
+ Example:
+ >>> buf = from_pcm(pcm_data, sample_rate=16000)
+ >>> highpass_filter(buf, cutoff_hz=50.0) # Remove hum
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ sr = sample_rate if sample_rate is not None else audio.sample_rate
+ native.audio_highpass_filter(audio.data._get_native(), cutoff_hz, sr)
+ return audio
+ else:
+ sr = sample_rate if sample_rate is not None else 16000
+ native.audio_highpass_filter(audio._get_native(), cutoff_hz, sr)
+ return audio
+
+
+def noise_gate(audio: AudioBuffer | GPUArray, threshold: float = 0.01) -> AudioBuffer | GPUArray:
+ """Apply simple noise gate.
+
+ Zeros samples with absolute value below threshold. This is a hard gate
+ that completely silences quiet sections.
+
+ Args:
+ audio: AudioBuffer or GPUArray of float32 samples
+ threshold: Amplitude threshold (default 0.01)
+
+ Returns:
+ Same type as input (modified in-place)
+
+ Example:
+ >>> buf = from_pcm(pcm_data, sample_rate=16000)
+ >>> noise_gate(buf, threshold=0.02)
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ native.audio_noise_gate(audio.data._get_native(), threshold)
+ return audio
+ else:
+ native.audio_noise_gate(audio._get_native(), threshold)
+ return audio
+
+
+def spectral_gate(
+ audio: AudioBuffer | GPUArray,
+ threshold: float = 0.01,
+ attack_samples: int = 64,
+ release_samples: int = 256,
+) -> AudioBuffer | GPUArray:
+ """Apply spectral gate for noise reduction.
+
+ A softer noise gate that attenuates (rather than silences) quiet sections
+ based on short-term frame energy. Provides smoother transitions than
+ a hard noise gate.
+
+ Args:
+ audio: AudioBuffer or GPUArray of float32 samples
+ threshold: Energy threshold (linear scale, default 0.01)
+ attack_samples: Frame size for energy computation (default 64)
+ release_samples: Smoothing release in samples (default 256)
+
+ Returns:
+ Same type as input (modified in-place)
+
+ Example:
+ >>> buf = from_pcm(pcm_data, sample_rate=16000)
+ >>> spectral_gate(buf, threshold=0.005) # Subtle noise reduction
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ native.audio_spectral_gate(
+ audio.data._get_native(), threshold, attack_samples, release_samples
+ )
+ return audio
+ else:
+ native.audio_spectral_gate(audio._get_native(), threshold, attack_samples, release_samples)
+ return audio
+
+
+def compute_short_term_energy(audio: AudioBuffer | GPUArray, frame_size: int = 256) -> GPUArray:
+ """Compute short-term energy for analysis or adaptive processing.
+
+ Divides the audio into non-overlapping frames and computes the mean
+ energy (sum of squares / frame_size) for each frame.
+
+ Args:
+ audio: AudioBuffer or GPUArray of float32 samples
+ frame_size: Frame size in samples (default 256)
+
+ Returns:
+ GPUArray of frame energies
+
+ Example:
+ >>> buf = from_pcm(pcm_data, sample_rate=16000)
+ >>> energy = compute_short_term_energy(buf, frame_size=320) # 20ms @ 16kHz
+ >>> print(f"Max energy: {energy.to_numpy().max():.4f}")
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ data = audio.data
+ else:
+ data = audio
+
+ result = native.audio_compute_short_term_energy(data._get_native(), frame_size)
+ return GPUArray._wrap_native(result)
+
+
+__all__ = [
+ "preemphasis",
+ "deemphasis",
+ "remove_dc",
+ "highpass_filter",
+ "noise_gate",
+ "spectral_gate",
+ "compute_short_term_energy",
+]
diff --git a/src/pygpukit/ops/audio/spectral.py b/src/pygpukit/ops/audio/spectral.py
new file mode 100644
index 0000000..21254a4
--- /dev/null
+++ b/src/pygpukit/ops/audio/spectral.py
@@ -0,0 +1,338 @@
+"""Spectral processing functions for GPU audio processing.
+
+This module provides:
+- STFT (Short-Time Fourier Transform)
+- Power and magnitude spectrum
+- Mel filterbank operations
+- Log-mel spectrogram
+- MFCC
+- Delta features
+"""
+
+from __future__ import annotations
+
+from pygpukit.core import GPUArray
+
+from .buffer import AudioBuffer
+
+
+def _get_native():
+ """Get the native module."""
+ try:
+ from pygpukit._native_loader import get_native_module
+
+ return get_native_module()
+ except ImportError:
+ from pygpukit import _pygpukit_native
+
+ return _pygpukit_native
+
+
+def stft(
+ audio: AudioBuffer | GPUArray,
+ n_fft: int = 512,
+ hop_length: int = 160,
+ win_length: int = -1,
+ center: bool = True,
+) -> GPUArray:
+ """Compute Short-Time Fourier Transform (STFT).
+
+ Uses a custom Radix-2 FFT implementation (no cuFFT dependency).
+
+ Args:
+ audio: AudioBuffer or GPUArray of float32 samples
+ n_fft: FFT size (must be power of 2, default 512)
+ hop_length: Hop size (default 160)
+ win_length: Window length (default n_fft)
+ center: Whether to pad input with reflection (default True)
+
+ Returns:
+ Complex STFT output [n_frames, n_fft/2+1, 2] (real, imag)
+
+ Example:
+ >>> buf = from_pcm(pcm_data, sample_rate=16000)
+ >>> stft_out = stft(buf, n_fft=512, hop_length=160)
+ >>> print(f"STFT shape: {stft_out.shape}") # [n_frames, 257, 2]
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ data = audio.data
+ else:
+ data = audio
+
+ result = native.audio_stft(data._get_native(), n_fft, hop_length, win_length, center)
+ return GPUArray._wrap_native(result)
+
+
+def power_spectrum(stft_output: GPUArray) -> GPUArray:
+ """Compute power spectrogram from STFT output.
+
+ power = real^2 + imag^2
+
+ Args:
+ stft_output: STFT output [n_frames, n_freq, 2]
+
+ Returns:
+ Power spectrogram [n_frames, n_freq]
+
+ Example:
+ >>> stft_out = stft(buf, n_fft=512)
+ >>> power = power_spectrum(stft_out)
+ """
+ native = _get_native()
+ result = native.audio_power_spectrum(stft_output._get_native())
+ return GPUArray._wrap_native(result)
+
+
+def magnitude_spectrum(stft_output: GPUArray) -> GPUArray:
+ """Compute magnitude spectrogram from STFT output.
+
+ magnitude = sqrt(real^2 + imag^2)
+
+ Args:
+ stft_output: STFT output [n_frames, n_freq, 2]
+
+ Returns:
+ Magnitude spectrogram [n_frames, n_freq]
+
+ Example:
+ >>> stft_out = stft(buf, n_fft=512)
+ >>> mag = magnitude_spectrum(stft_out)
+ """
+ native = _get_native()
+ result = native.audio_magnitude_spectrum(stft_output._get_native())
+ return GPUArray._wrap_native(result)
+
+
+def create_mel_filterbank(
+ n_mels: int = 80,
+ n_fft: int = 512,
+ sample_rate: int = 16000,
+ f_min: float = 0.0,
+ f_max: float = -1.0,
+) -> GPUArray:
+ """Create Mel filterbank matrix.
+
+ Args:
+ n_mels: Number of mel bands (default 80 for Whisper)
+ n_fft: FFT size
+ sample_rate: Sample rate in Hz
+ f_min: Minimum frequency (default 0)
+ f_max: Maximum frequency (default sample_rate/2)
+
+ Returns:
+ Mel filterbank matrix [n_mels, n_fft/2+1]
+
+ Example:
+ >>> mel_fb = create_mel_filterbank(n_mels=80, n_fft=512, sample_rate=16000)
+ """
+ native = _get_native()
+ result = native.audio_create_mel_filterbank(n_mels, n_fft, sample_rate, f_min, f_max)
+ return GPUArray._wrap_native(result)
+
+
+def apply_mel_filterbank(spectrogram: GPUArray, mel_filterbank: GPUArray) -> GPUArray:
+ """Apply Mel filterbank to power/magnitude spectrogram.
+
+ Args:
+ spectrogram: Input spectrogram [n_frames, n_fft/2+1]
+ mel_filterbank: Mel filterbank [n_mels, n_fft/2+1]
+
+ Returns:
+ Mel spectrogram [n_frames, n_mels]
+
+ Example:
+ >>> power = power_spectrum(stft_out)
+ >>> mel_fb = create_mel_filterbank(n_mels=80, n_fft=512)
+ >>> mel = apply_mel_filterbank(power, mel_fb)
+ """
+ native = _get_native()
+ result = native.audio_apply_mel_filterbank(
+ spectrogram._get_native(), mel_filterbank._get_native()
+ )
+ return GPUArray._wrap_native(result)
+
+
+def log_mel(mel_spectrogram: GPUArray, eps: float = 1e-10) -> GPUArray:
+ """Compute log-mel spectrogram.
+
+ log_mel = log(mel + eps)
+
+ Args:
+ mel_spectrogram: Mel spectrogram [n_frames, n_mels]
+ eps: Small constant for numerical stability (default 1e-10)
+
+ Returns:
+ Log-mel spectrogram [n_frames, n_mels]
+
+ Example:
+ >>> log_mel_spec = log_mel(mel_spectrogram)
+ """
+ native = _get_native()
+ result = native.audio_log_mel_spectrogram(mel_spectrogram._get_native(), eps)
+ return GPUArray._wrap_native(result)
+
+
+def to_decibels(audio: AudioBuffer | GPUArray, eps: float = 1e-10) -> GPUArray:
+ """Convert to decibels.
+
+ dB = 10 * log10(x + eps)
+
+ Args:
+ audio: Input array (power values)
+ eps: Small constant for numerical stability (default 1e-10)
+
+ Returns:
+ dB values
+
+ Example:
+ >>> power = power_spectrum(stft_out)
+ >>> db = to_decibels(power)
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ data = audio.data
+ else:
+ data = audio
+
+ result = native.audio_to_decibels(data._get_native(), eps)
+ return GPUArray._wrap_native(result)
+
+
+def mfcc(log_mel_input: GPUArray, n_mfcc: int = 13) -> GPUArray:
+ """Compute MFCC from log-mel spectrogram using DCT-II.
+
+ Args:
+ log_mel_input: Log-mel spectrogram [n_frames, n_mels]
+ n_mfcc: Number of MFCC coefficients (default 13)
+
+ Returns:
+ MFCC [n_frames, n_mfcc]
+
+ Example:
+ >>> log_mel_spec = log_mel(mel_spectrogram)
+ >>> mfcc_features = mfcc(log_mel_spec, n_mfcc=13)
+ """
+ native = _get_native()
+ result = native.audio_mfcc(log_mel_input._get_native(), n_mfcc)
+ return GPUArray._wrap_native(result)
+
+
+def delta(features: GPUArray, order: int = 1, width: int = 2) -> GPUArray:
+ """Compute delta (differential) features.
+
+ Args:
+ features: Input features [n_frames, n_features]
+ order: Delta order (1 for delta, 2 for delta-delta)
+ width: Window width for computation (default 2)
+
+ Returns:
+ Delta features [n_frames, n_features]
+
+ Example:
+ >>> mfcc_features = mfcc(log_mel_spec)
+ >>> delta_mfcc = delta(mfcc_features, order=1)
+ >>> delta_delta_mfcc = delta(mfcc_features, order=2)
+ """
+ native = _get_native()
+ result = native.audio_delta_features(features._get_native(), order, width)
+ return GPUArray._wrap_native(result)
+
+
+def mel_spectrogram(
+ audio: AudioBuffer | GPUArray,
+ n_fft: int = 512,
+ hop_length: int = 160,
+ n_mels: int = 80,
+ sample_rate: int = 16000,
+ f_min: float = 0.0,
+ f_max: float = -1.0,
+) -> GPUArray:
+ """Compute mel spectrogram.
+
+ Combines: STFT -> power -> mel filterbank
+
+ Args:
+ audio: Input audio (float32)
+ n_fft: FFT size (must be power of 2)
+ hop_length: Hop size
+ n_mels: Number of mel bands
+ sample_rate: Sample rate in Hz
+ f_min: Minimum frequency
+ f_max: Maximum frequency (-1 for sample_rate/2)
+
+ Returns:
+ Mel spectrogram [n_frames, n_mels]
+
+ Example:
+ >>> mel = mel_spectrogram(buf, n_fft=512, hop_length=160, n_mels=80)
+ """
+ if isinstance(audio, AudioBuffer):
+ data = audio.data
+ else:
+ data = audio
+
+ # STFT
+ stft_out = stft(data, n_fft=n_fft, hop_length=hop_length, center=True)
+
+ # Power spectrum
+ power = power_spectrum(stft_out)
+
+ # Create and apply mel filterbank
+ mel_fb = create_mel_filterbank(n_mels, n_fft, sample_rate, f_min, f_max)
+ mel = apply_mel_filterbank(power, mel_fb)
+
+ return mel
+
+
+def log_mel_spectrogram(
+ audio: AudioBuffer | GPUArray,
+ n_fft: int = 512,
+ hop_length: int = 160,
+ n_mels: int = 80,
+ sample_rate: int = 16000,
+ f_min: float = 0.0,
+ f_max: float = -1.0,
+ eps: float = 1e-10,
+) -> GPUArray:
+ """Compute log-mel spectrogram (Whisper-compatible).
+
+ Combines: STFT -> power -> mel filterbank -> log
+
+ Args:
+ audio: Input audio (float32, 16kHz expected for Whisper)
+ n_fft: FFT size (must be power of 2)
+ hop_length: Hop size
+ n_mels: Number of mel bands (80 for Whisper)
+ sample_rate: Sample rate in Hz
+ f_min: Minimum frequency
+ f_max: Maximum frequency (-1 for sample_rate/2)
+ eps: Small constant for log stability
+
+ Returns:
+ Log-mel spectrogram [n_frames, n_mels]
+
+ Example:
+ >>> # Whisper-style mel spectrogram
+ >>> buf = from_pcm(pcm_data, sample_rate=16000)
+ >>> log_mel = log_mel_spectrogram(buf, n_fft=512, hop_length=160, n_mels=80)
+ """
+ mel = mel_spectrogram(audio, n_fft, hop_length, n_mels, sample_rate, f_min, f_max)
+ return log_mel(mel, eps)
+
+
+__all__ = [
+ "stft",
+ "power_spectrum",
+ "magnitude_spectrum",
+ "create_mel_filterbank",
+ "apply_mel_filterbank",
+ "log_mel",
+ "to_decibels",
+ "mfcc",
+ "delta",
+ "mel_spectrogram",
+ "log_mel_spectrogram",
+]
diff --git a/src/pygpukit/ops/audio/vad.py b/src/pygpukit/ops/audio/vad.py
new file mode 100644
index 0000000..68883ca
--- /dev/null
+++ b/src/pygpukit/ops/audio/vad.py
@@ -0,0 +1,223 @@
+"""Voice Activity Detection (VAD) for GPU audio processing.
+
+This module provides:
+- VAD: GPU-accelerated Voice Activity Detection
+- SpeechSegment: Detected speech segment data class
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import numpy as np
+
+from pygpukit.core import GPUArray
+
+from .buffer import AudioBuffer
+
+
+def _get_native():
+ """Get the native module."""
+ try:
+ from pygpukit._native_loader import get_native_module
+
+ return get_native_module()
+ except ImportError:
+ from pygpukit import _pygpukit_native
+
+ return _pygpukit_native
+
+
+@dataclass
+class SpeechSegment:
+ """Represents a detected speech segment.
+
+ Attributes:
+ start_sample: Start sample index
+ end_sample: End sample index
+ start_time: Start time in seconds
+ end_time: End time in seconds
+ """
+
+ start_sample: int
+ end_sample: int
+ start_time: float
+ end_time: float
+
+
+class VAD:
+ """GPU-accelerated Voice Activity Detection.
+
+ Detects speech segments in audio using energy and zero-crossing rate features.
+ Supports adaptive thresholding and hangover smoothing for robust detection.
+
+ Args:
+ sample_rate: Audio sample rate in Hz (default: 16000)
+ frame_ms: Frame duration in milliseconds (default: 20)
+ hop_ms: Hop duration in milliseconds (default: 10)
+ energy_threshold: Energy threshold for speech (default: auto)
+ hangover_ms: Hangover duration in milliseconds (default: 100)
+
+ Example:
+ >>> vad = VAD(sample_rate=16000)
+ >>> segments = vad.detect(audio_buffer)
+ >>> for seg in segments:
+ ... print(f"Speech: {seg.start_time:.2f}s - {seg.end_time:.2f}s")
+ """
+
+ def __init__(
+ self,
+ sample_rate: int = 16000,
+ frame_ms: float = 20.0,
+ hop_ms: float = 10.0,
+ energy_threshold: float | None = None,
+ hangover_ms: float = 100.0,
+ zcr_low: float = 0.02,
+ zcr_high: float = 0.25,
+ ):
+ self._sample_rate = sample_rate
+ self._frame_size = int(frame_ms * sample_rate / 1000)
+ self._hop_size = int(hop_ms * sample_rate / 1000)
+ self._energy_threshold = energy_threshold
+ self._hangover_frames = int(hangover_ms / hop_ms)
+ self._zcr_low = zcr_low
+ self._zcr_high = zcr_high
+
+ # Adaptive threshold multiplier (above noise floor)
+ self._adaptive_multiplier = 3.0
+
+ @property
+ def sample_rate(self) -> int:
+ """Sample rate in Hz."""
+ return self._sample_rate
+
+ @property
+ def frame_size(self) -> int:
+ """Frame size in samples."""
+ return self._frame_size
+
+ @property
+ def hop_size(self) -> int:
+ """Hop size in samples."""
+ return self._hop_size
+
+ def detect(self, audio: AudioBuffer | GPUArray) -> list[SpeechSegment]:
+ """Detect speech segments in audio.
+
+ Args:
+ audio: AudioBuffer or GPUArray of float32 samples
+
+ Returns:
+ List of SpeechSegment objects representing detected speech regions
+ """
+ native = _get_native()
+
+ # Get audio data
+ if isinstance(audio, AudioBuffer):
+ data = audio.data
+ else:
+ data = audio
+
+ # Compute frame features
+ energy = native.vad_compute_energy(data._get_native(), self._frame_size, self._hop_size)
+ zcr = native.vad_compute_zcr(data._get_native(), self._frame_size, self._hop_size)
+
+ energy_gpu = GPUArray._wrap_native(energy)
+ zcr_gpu = GPUArray._wrap_native(zcr)
+
+ # Determine energy threshold
+ if self._energy_threshold is not None:
+ threshold = self._energy_threshold
+ else:
+ # Adaptive threshold: multiplier * noise_floor
+ noise_floor = native.vad_compute_noise_floor(energy)
+ threshold = max(noise_floor * self._adaptive_multiplier, 0.01)
+
+ # VAD decision
+ vad_flags = native.vad_decide(
+ energy_gpu._get_native(),
+ zcr_gpu._get_native(),
+ threshold,
+ self._zcr_low,
+ self._zcr_high,
+ )
+ vad_flags_gpu = GPUArray._wrap_native(vad_flags)
+
+ # Apply hangover smoothing
+ if self._hangover_frames > 0:
+ smoothed = native.vad_apply_hangover(vad_flags_gpu._get_native(), self._hangover_frames)
+ vad_flags_gpu = GPUArray._wrap_native(smoothed)
+
+ # Convert to segments
+ return self._flags_to_segments(vad_flags_gpu)
+
+ def _flags_to_segments(self, vad_flags: GPUArray) -> list[SpeechSegment]:
+ """Convert frame-level VAD flags to speech segments."""
+ flags: np.ndarray = vad_flags.to_numpy().astype(int)
+
+ segments: list[SpeechSegment] = []
+ in_speech = False
+ start_frame = 0
+
+ for i, flag in enumerate(flags):
+ if flag == 1 and not in_speech:
+ # Speech start
+ in_speech = True
+ start_frame = i
+ elif flag == 0 and in_speech:
+ # Speech end
+ in_speech = False
+ segments.append(self._create_segment(start_frame, i))
+
+ # Handle case where speech continues to end
+ if in_speech:
+ segments.append(self._create_segment(start_frame, len(flags)))
+
+ return segments
+
+ def _create_segment(self, start_frame: int, end_frame: int) -> SpeechSegment:
+ """Create a SpeechSegment from frame indices."""
+ start_sample = start_frame * self._hop_size
+ end_sample = end_frame * self._hop_size + self._frame_size
+
+ return SpeechSegment(
+ start_sample=start_sample,
+ end_sample=end_sample,
+ start_time=start_sample / self._sample_rate,
+ end_time=end_sample / self._sample_rate,
+ )
+
+ def get_frame_features(self, audio: AudioBuffer | GPUArray) -> tuple[GPUArray, GPUArray]:
+ """Get raw frame features (energy and ZCR) for analysis.
+
+ Args:
+ audio: AudioBuffer or GPUArray of float32 samples
+
+ Returns:
+ Tuple of (energy, zcr) GPUArrays
+ """
+ native = _get_native()
+
+ if isinstance(audio, AudioBuffer):
+ data = audio.data
+ else:
+ data = audio
+
+ energy = native.vad_compute_energy(data._get_native(), self._frame_size, self._hop_size)
+ zcr = native.vad_compute_zcr(data._get_native(), self._frame_size, self._hop_size)
+
+ return GPUArray._wrap_native(energy), GPUArray._wrap_native(zcr)
+
+ def __repr__(self) -> str:
+ return (
+ f"VAD(sample_rate={self._sample_rate}, "
+ f"frame_size={self._frame_size}, "
+ f"hop_size={self._hop_size}, "
+ f"hangover_frames={self._hangover_frames})"
+ )
+
+
+__all__ = [
+ "SpeechSegment",
+ "VAD",
+]
diff --git a/src/pygpukit/ops/matmul.py b/src/pygpukit/ops/matmul.py
deleted file mode 100644
index 791667b..0000000
--- a/src/pygpukit/ops/matmul.py
+++ /dev/null
@@ -1,2087 +0,0 @@
-"""Matrix multiplication operations for GPUArrays.
-
-Corresponds to native/ops/matmul/.
-"""
-
-from __future__ import annotations
-
-import warnings
-
-import numpy as np
-
-from pygpukit.core.array import GPUArray
-from pygpukit.core.backend import NativeBackend, get_backend
-from pygpukit.core.factory import from_numpy
-from pygpukit.ops._common import _validate_float_dtype, _validate_same_dtype
-
-
-def matmul(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
- use_tf32: bool | None = None,
-) -> GPUArray:
- """Matrix multiplication of two 2D arrays.
-
- Args:
- a: First input array (M x K).
- b: Second input array (K x N).
- out: Optional output array (M x N). If provided, result is written to this
- array instead of allocating a new one. This enables CUDA Graph capture
- since no memory allocation occurs during the operation.
- use_tf32: Whether to use TF32 TensorCore acceleration (Ampere+ only).
- - None (default): Use PYGPUKIT_ALLOW_TF32 environment variable
- - True: Force TF32 mode (requires SM >= 80 and float32)
- - False: Force FP32 mode
-
- Returns:
- The result GPUArray (M x N). If out is provided, returns out.
-
- Raises:
- ValueError: If arrays are not 2D or dimensions don't match.
- RuntimeError: If use_tf32=True but GPU doesn't support it or dtype is not float32.
-
- Example:
- # Allocate new output
- y = pk.matmul(x, W)
-
- # Write to existing buffer (for CUDA Graph capture)
- pk.matmul(x, W, out=y)
- """
- if a.ndim != 2:
- raise ValueError(f"matmul requires 2D arrays, got {a.ndim}D for first argument")
- if b.ndim != 2:
- raise ValueError(f"matmul requires 2D arrays, got {b.ndim}D for second argument")
-
- if a.shape[1] != b.shape[0]:
- raise ValueError(
- f"matmul dimension mismatch: {a.shape} @ {b.shape} "
- f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)"
- )
-
- _validate_same_dtype(a, b, "matmul")
-
- # Validate out array if provided
- if out is not None:
- expected_shape = (a.shape[0], b.shape[1])
- if out.shape != expected_shape:
- raise ValueError(f"out shape {out.shape} does not match expected {expected_shape}")
- if out.dtype != a.dtype:
- raise ValueError(f"out dtype {out.dtype} does not match input dtype {a.dtype}")
-
- # Check TF32 dtype requirement early (before backend dispatch)
- if use_tf32 is True:
- from pygpukit.core.dtypes import float32
-
- if a.dtype != float32:
- raise RuntimeError("TF32 matmul requires float32 dtype")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _matmul_native(a, b, out=out, use_tf32=use_tf32)
- else:
- return _matmul_cpu(a, b, out=out)
-
-
-def _matmul_cpu(a: GPUArray, b: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
- """CPU implementation of matmul."""
- a_np = a.to_numpy()
- b_np = b.to_numpy()
- if out is not None:
- out_np = out.to_numpy()
- np.matmul(a_np, b_np, out=out_np)
- # Copy back to GPU - this is inefficient but CPU backend is for fallback only
- out._data = from_numpy(out_np)._data
- return out
- else:
- result_np = np.matmul(a_np, b_np)
- return from_numpy(result_np)
-
-
-def _matmul_native(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
- use_tf32: bool | None = None,
-) -> GPUArray:
- """Native C++ CUDA implementation of matmul (zero-copy).
-
- Args:
- a: First input array.
- b: Second input array.
- out: Optional output array. If provided, result is written in-place.
- use_tf32: Whether to use TF32 TensorCore acceleration.
- None means use environment variable PYGPUKIT_ALLOW_TF32.
- """
-
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- # Get native arrays (zero-copy if already native)
- a_native = a._get_native()
- b_native = b._get_native()
-
- if out is not None:
- # In-place operation - write to existing buffer
- out_native = out._get_native()
- if use_tf32 is not None:
- native.matmul_tf32_(a_native, b_native, out_native, use_tf32)
- else:
- native.matmul_(a_native, b_native, out_native)
- return out
- else:
- # Allocate new output
- if use_tf32 is not None:
- c_native = native.matmul_tf32(a_native, b_native, use_tf32)
- else:
- c_native = native.matmul(a_native, b_native)
- return GPUArray._wrap_native(c_native)
-
-
-def transpose(a: GPUArray) -> GPUArray:
- """Matrix transpose.
-
- Args:
- a: Input array of shape [rows, cols].
-
- Returns:
- A new GPUArray of shape [cols, rows] containing a.T.
-
- Raises:
- ValueError: If input is not 2D.
- """
- if a.ndim != 2:
- raise ValueError(f"transpose expects 2D input [rows, cols], got {a.ndim}D")
-
- from pygpukit.core.dtypes import uint8
-
- backend = get_backend()
-
- # For uint8 (FP8 weights), use CPU fallback since native transpose
- # doesn't support integer types
- if a.dtype == uint8:
- return _transpose_cpu(a)
-
- _validate_float_dtype(a, "transpose")
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _transpose_native(a)
- else:
- return _transpose_cpu(a)
-
-
-def _transpose_cpu(a: GPUArray) -> GPUArray:
- """CPU implementation of transpose."""
- a_np = a.to_numpy()
- return from_numpy(a_np.T.copy())
-
-
-def _transpose_native(a: GPUArray) -> GPUArray:
- """Native C++ CUDA implementation of transpose (zero-copy)."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- a_native = a._get_native()
- c_native = native.transpose(a_native)
- return GPUArray._wrap_native(c_native)
-
-
-def linear_bias_gelu(
- input: GPUArray,
- weight: GPUArray,
- bias: GPUArray,
-) -> GPUArray:
- """Fused linear + bias + GELU operation.
-
- Computes: output = gelu(input @ weight^T + bias)
-
- When dimensions are multiples of 16, this uses CUTLASS TensorCore
- epilogue fusion for efficiency. Otherwise, falls back to separate
- matmul + bias_add + gelu operations.
-
- Args:
- input: Input array of shape [batch, in_features].
- weight: Weight array of shape [out_features, in_features].
- bias: Bias array of shape [out_features].
-
- Returns:
- A new GPUArray of shape [batch, out_features].
-
- Raises:
- ValueError: If shapes or dtypes don't match.
-
- Note:
- Best performance when dimensions are multiples of 16 (uses TensorCore).
- Non-aligned dimensions use native fallback path.
- """
- _validate_float_dtype(input, "linear_bias_gelu")
-
- if input.ndim != 2:
- raise ValueError(
- f"linear_bias_gelu expects 2D input [batch, in_features], got {input.ndim}D"
- )
- if weight.ndim != 2:
- raise ValueError(
- f"linear_bias_gelu expects 2D weight [out_features, in_features], got {weight.ndim}D"
- )
- if bias.ndim != 1:
- raise ValueError(f"linear_bias_gelu expects 1D bias [out_features], got {bias.ndim}D")
-
- if input.dtype != weight.dtype or input.dtype != bias.dtype:
- raise ValueError("linear_bias_gelu: all inputs must have same dtype")
-
- in_features = input.shape[1]
- out_features = weight.shape[0]
-
- if weight.shape[1] != in_features:
- raise ValueError(
- f"linear_bias_gelu: weight.shape[1]={weight.shape[1]} must match "
- f"input.shape[1]={in_features}"
- )
- if bias.shape[0] != out_features:
- raise ValueError(
- f"linear_bias_gelu: bias.shape[0]={bias.shape[0]} must match "
- f"weight.shape[0]={out_features}"
- )
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _linear_bias_gelu_native(input, weight, bias)
- else:
- return _linear_bias_gelu_cpu(input, weight, bias)
-
-
-def _linear_bias_gelu_cpu(
- input: GPUArray,
- weight: GPUArray,
- bias: GPUArray,
-) -> GPUArray:
- """CPU implementation of linear_bias_gelu."""
- x = input.to_numpy()
- w = weight.to_numpy()
- b = bias.to_numpy()
-
- # Linear: y = x @ w.T + b
- y = x @ w.T + b
-
- # GELU approximation (same as GPU kernel)
- sqrt_2_over_pi = np.sqrt(2.0 / np.pi)
- result = y * 0.5 * (1.0 + np.tanh(sqrt_2_over_pi * (y + 0.044715 * y**3)))
-
- return from_numpy(result.astype(x.dtype))
-
-
-def _linear_bias_gelu_native(
- input: GPUArray,
- weight: GPUArray,
- bias: GPUArray,
-) -> GPUArray:
- """Native C++ CUDA implementation of linear_bias_gelu (CUTLASS fused kernel)."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- input_native = input._get_native()
- weight_native = weight._get_native()
- bias_native = bias._get_native()
- c_native = native.linear_bias_gelu(input_native, weight_native, bias_native)
- return GPUArray._wrap_native(c_native)
-
-
-def batched_matmul(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Batched matrix multiplication for 3D and 4D tensors.
-
- Supports:
- - 3D: [batch, M, K] @ [batch, K, N] -> [batch, M, N]
- - 4D: [batch1, batch2, M, K] @ [batch1, batch2, K, N] -> [batch1, batch2, M, N]
-
- Args:
- a: First input array (3D or 4D).
- b: Second input array (3D or 4D).
- out: Optional output array. If provided, result is written in-place.
-
- Returns:
- The result GPUArray with shape [..., M, N].
-
- Raises:
- ValueError: If arrays are not 3D/4D or dimensions don't match.
- """
- if a.ndim not in (3, 4):
- raise ValueError(f"batched_matmul requires 3D or 4D arrays, got {a.ndim}D")
- if b.ndim not in (3, 4):
- raise ValueError(f"batched_matmul requires 3D or 4D arrays, got {b.ndim}D")
- if a.ndim != b.ndim:
- raise ValueError(f"batched_matmul requires same ndim, got {a.ndim}D and {b.ndim}D")
-
- _validate_same_dtype(a, b, "batched_matmul")
-
- # Extract dimensions
- if a.ndim == 3:
- batch = a.shape[0]
- M, K = a.shape[1], a.shape[2]
- K2, N = b.shape[1], b.shape[2]
- if b.shape[0] != batch:
- raise ValueError(f"Batch dimension mismatch: {a.shape[0]} vs {b.shape[0]}")
- if K != K2:
- raise ValueError(f"Inner dimension mismatch: {K} vs {K2}")
- out_shape = (batch, M, N)
- batch_count = batch
- else: # 4D
- batch1, batch2 = a.shape[0], a.shape[1]
- M, K = a.shape[2], a.shape[3]
- K2, N = b.shape[2], b.shape[3]
- if b.shape[0] != batch1 or b.shape[1] != batch2:
- raise ValueError(
- f"Batch dimensions mismatch: ({batch1}, {batch2}) vs ({b.shape[0]}, {b.shape[1]})"
- )
- if K != K2:
- raise ValueError(f"Inner dimension mismatch: {K} vs {K2}")
- out_shape = (batch1, batch2, M, N)
- batch_count = batch1 * batch2
-
- # Validate output
- if out is not None:
- if out.shape != out_shape:
- raise ValueError(f"out shape {out.shape} does not match expected {out_shape}")
- if out.dtype != a.dtype:
- raise ValueError(f"out dtype {out.dtype} does not match input dtype {a.dtype}")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _batched_matmul_native(a, b, M, N, K, batch_count, out_shape, out=out)
- else:
- return _batched_matmul_cpu(a, b, out=out)
-
-
-def _batched_matmul_cpu(a: GPUArray, b: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
- """CPU implementation of batched_matmul."""
- a_np = a.to_numpy()
- b_np = b.to_numpy()
- result_np = np.matmul(a_np, b_np)
- result = from_numpy(result_np)
-
- if out is not None:
- # Copy result to output buffer
- from ..ops.elementwise import copy_to
-
- copy_to(result, out)
- return out
- else:
- return result
-
-
-def _batched_matmul_loop(
- a: GPUArray, b: GPUArray, out_shape: tuple[int, ...], *, out: GPUArray | None = None
-) -> GPUArray:
- """GPU batched matmul using loop over individual matmuls.
-
- This is a fallback for when CUTLASS strided batched GEMM is not available
- (e.g., SM 120). Uses native matmul kernel for each batch element.
- """
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- # Reshape to 3D for easier iteration: [batch, M, K] @ [batch, K, N]
- if a.ndim == 4:
- batch1, batch2 = a.shape[0], a.shape[1]
- M, K = a.shape[2], a.shape[3]
- N = b.shape[3]
- total_batch = batch1 * batch2
-
- a_3d = a.reshape(total_batch, M, K)
- b_3d = b.reshape(total_batch, K, N)
- else:
- total_batch = a.shape[0]
- M, K = a.shape[1], a.shape[2]
- N = b.shape[2]
-
- a_3d = a
- b_3d = b
-
- # Allocate output
- if out is None:
- out_native = native.empty(list(out_shape), native.DataType.Float32)
- out = GPUArray._wrap_native(out_native)
-
- # Perform batched matmul via loop
- for i in range(total_batch):
- # Extract slice (creates view/copy depending on implementation)
- a_i = a_3d.to_numpy()[i]
- b_i = b_3d.to_numpy()[i]
-
- a_gpu = from_numpy(a_i)
- b_gpu = from_numpy(b_i)
-
- # Compute matmul for this batch element
- c_gpu = matmul(a_gpu, b_gpu)
-
- # Copy result to output
- out_np = out.to_numpy()
- if a.ndim == 4:
- i1, i2 = i // batch2, i % batch2
- out_np[i1, i2] = c_gpu.to_numpy()
- else:
- out_np[i] = c_gpu.to_numpy()
- out = from_numpy(out_np)
-
- return out
-
-
-def _batched_matmul_native(
- a: GPUArray,
- b: GPUArray,
- M: int,
- N: int,
- K: int,
- batch_count: int,
- out_shape: tuple[int, ...],
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Native cuBLASLt strided batched GEMM implementation."""
- from pygpukit.core.backend import get_native_module
- from pygpukit.core.dtypes import float32
-
- native = get_native_module()
-
- # Currently only FP32 supported via cuBLASLt strided batched
- if a.dtype != float32:
- warnings.warn(
- f"batched_matmul: GPU kernel requires float32, got {a.dtype}. Using CPU fallback (slow)",
- RuntimeWarning,
- stacklevel=3,
- )
- return _batched_matmul_cpu(a, b, out=out)
-
- # Compute strides for strided batched GEMM
- strideA = M * K
- strideB = K * N
- strideC = M * N
-
- # Get native arrays
- a_native = a._get_native()
- b_native = b._get_native()
-
- # Allocate output if needed (using native allocation)
- if out is None:
- out_native = native.empty(list(out_shape), native.DataType.Float32)
- out = GPUArray._wrap_native(out_native)
- else:
- out_native = out._get_native()
-
- # Call strided batched GEMM with CPU fallback for unsupported architectures
- try:
- native.gemm_strided_batched_fp32(
- a_native,
- b_native,
- out_native,
- M,
- N,
- K,
- batch_count,
- strideA,
- strideB,
- strideC,
- )
- except RuntimeError:
- # CUTLASS not available/failed (e.g., SM 120) - fall back to CPU
- warnings.warn(
- "batched_matmul: CUTLASS kernel failed, using CPU fallback (slow)",
- RuntimeWarning,
- stacklevel=3,
- )
- return _batched_matmul_cpu(a, b, out=out)
-
- return out
-
-
-def fp8_available() -> bool:
- """Check if FP8 GEMM is available (any backend).
-
- Returns:
- True if FP8 GEMM is available (requires SM90+ GPU).
- """
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- # Check all FP8 backends - return True if any is available
- return (
- native.gemm_fp8_f32_sm90_available()
- or native.gemm_fp8_f32_sm100_available()
- or native.gemm_fp8_f32_sm120_available()
- )
- else:
- return False
-
-
-# Alias for standardized naming
-gemm_fp8_available = fp8_available
-
-
-def fp8_sm90_available() -> bool:
- """Check if FP8 GEMM is available on SM90 (Hopper).
-
- Returns:
- True if FP8 GEMM is available (requires SM90+ and CUTLASS SM90 support).
- """
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- # Use new standardized name
- return native.gemm_fp8_f32_sm90_available()
- else:
- return False
-
-
-# Alias for standardized naming
-gemm_fp8_f32_sm90_available = fp8_sm90_available
-
-
-def fp8_sm100_available() -> bool:
- """Check if FP8 GEMM is available on SM100 (Blackwell datacenter).
-
- This may work on SM120 (Blackwell GeForce) as a fallback since both
- are Blackwell architecture.
-
- Returns:
- True if FP8 GEMM is available (requires SM100+ and CUTLASS SM100 support).
- """
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- # Use new standardized name
- return native.gemm_fp8_f32_sm100_available()
- else:
- return False
-
-
-# Alias for standardized naming
-gemm_fp8_f32_sm100_available = fp8_sm100_available
-
-
-def fp8_sm120_available() -> bool:
- """Check if FP8 GEMM is available on SM120 (Blackwell GeForce).
-
- Note: Currently disabled due to CUTLASS bug #2902.
-
- Returns:
- True if FP8 GEMM is available (requires SM120+ and CUTLASS SM120 support).
- """
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- # Use new standardized name
- return native.gemm_fp8_f32_sm120_available()
- else:
- return False
-
-
-# Alias for standardized naming
-gemm_fp8_f32_sm120_available = fp8_sm120_available
-
-
-def fp8_fp8_sm120_available() -> bool:
- """Check if Pure FP8 I/O GEMM is available on SM120 (Blackwell GeForce).
-
- This is for FP8 models where weights and activations are already in FP8 format.
-
- Returns:
- True if Pure FP8 GEMM is available (requires SM120+ and CUTLASS SM120 support).
- """
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- # Use new standardized name
- return native.gemm_fp8_fp8_sm120_available()
- else:
- return False
-
-
-# Alias for standardized naming
-gemm_fp8_fp8_sm120_available = fp8_fp8_sm120_available
-
-
-def matmul_fp8_fp8_sm120(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Pure FP8 I/O matrix multiplication for SM120 (Blackwell GeForce).
-
- This function takes FP8 E4M3 inputs directly (no conversion from FP32),
- performs the GEMM using CUTLASS FP8 kernels, and returns FP8 E4M3 output.
-
- This is optimized for FP8 models (Llama 3.1 FP8, etc.) where weights
- and activations are already quantized to FP8.
-
- Args:
- a: First input array (M x K), FP8 E4M3 stored as uint8.
- b: Second input array (K x N), FP8 E4M3 stored as uint8.
- Should be in ColumnMajor format (pre-transposed).
- out: Optional output array (M x N), uint8. If provided, result is
- written to this array instead of allocating a new one.
-
- Returns:
- The result GPUArray (M x N), FP8 E4M3 stored as uint8.
-
- Raises:
- ValueError: If arrays are not 2D, dtypes are not uint8, or dimensions don't match.
- RuntimeError: If FP8 SM120 is not available.
-
- Example:
- >>> import pygpukit as gk
- >>> # Assuming A and B are already FP8 quantized (stored as uint8)
- >>> A = gk.from_numpy(fp8_a_data) # [M, K] uint8
- >>> B = gk.from_numpy(fp8_b_data) # [K, N] uint8 (ColumnMajor)
- >>> C = gk.ops.matmul_fp8_fp8_sm120(A, B) # [M, N] uint8
- """
- from pygpukit.core.dtypes import uint8
-
- if a.ndim != 2:
- raise ValueError(
- f"matmul_fp8_fp8_sm120 requires 2D arrays, got {a.ndim}D for first argument"
- )
- if b.ndim != 2:
- raise ValueError(
- f"matmul_fp8_fp8_sm120 requires 2D arrays, got {b.ndim}D for second argument"
- )
-
- if a.shape[1] != b.shape[0]:
- raise ValueError(
- f"matmul_fp8_fp8_sm120 dimension mismatch: {a.shape} @ {b.shape} "
- f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)"
- )
-
- if a.dtype != uint8 or b.dtype != uint8:
- raise ValueError("matmul_fp8_fp8_sm120 requires uint8 inputs (FP8 E4M3)")
-
- if not fp8_fp8_sm120_available():
- raise RuntimeError("Pure FP8 SM120 GEMM is not available. Requires SM120+ GPU.")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _matmul_fp8_fp8_sm120_native(a, b, out=out)
- else:
- raise RuntimeError("Pure FP8 SM120 GEMM requires native backend")
-
-
-def _matmul_fp8_fp8_sm120_native(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Native C++ implementation of Pure FP8 I/O GEMM for SM120."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- # Get native arrays
- a_native = a._get_native()
- b_native = b._get_native()
-
- # Allocate output if needed
- if out is None:
- M, K = a.shape
- N = b.shape[1]
- out_native = native.empty([M, N], native.DataType.UInt8)
- out = GPUArray._wrap_native(out_native)
- else:
- out_native = out._get_native()
-
- # Call Pure FP8 GEMM (use new standardized name)
- native.gemm_fp8_fp8_sm120(a_native, b_native, out_native)
-
- return out
-
-
-# Alias for standardized naming
-gemm_fp8_fp8_sm120 = matmul_fp8_fp8_sm120
-
-
-def fp8_fp8_get_scale_sizes(M: int, N: int, K: int) -> tuple[int, int]:
- """Get scale factor sizes for FP8 blockwise GEMM.
-
- Returns the required sizes for scale_A and scale_B arrays for the
- given problem dimensions. These sizes depend on the internal tile
- configuration of the CUTLASS kernel.
-
- Args:
- M: Number of rows in A and output.
- N: Number of columns in B and output.
- K: Inner dimension (columns of A, rows of B).
-
- Returns:
- Tuple of (scale_A_size, scale_B_size) as integers.
-
- Example:
- >>> sfa_size, sfb_size = fp8_fp8_get_scale_sizes(256, 256, 256)
- >>> scale_A = pk.from_numpy(np.ones(sfa_size, dtype=np.float32))
- >>> scale_B = pk.from_numpy(np.ones(sfb_size, dtype=np.float32))
- """
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- # Use new standardized name
- return native.gemm_fp8_fp8_get_scale_sizes(M, N, K)
- else:
- return (0, 0)
-
-
-# Alias for standardized naming
-gemm_fp8_fp8_get_scale_sizes = fp8_fp8_get_scale_sizes
-
-
-def matmul_fp8_fp8_blockwise_sm120(
- a: GPUArray,
- b: GPUArray,
- scale_a: GPUArray,
- scale_b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Blockwise scaled FP8 I/O matrix multiplication for SM120.
-
- This function takes FP8 E4M3 inputs with per-block scale factors,
- performs the GEMM using CUTLASS FP8 kernels, and returns FP8 E4M3 output.
-
- The scale factors are applied per block during the GEMM computation,
- enabling better precision for FP8 models with varied value ranges.
-
- Args:
- a: First input array (M x K), FP8 E4M3 stored as uint8.
- b: Second input array (K x N), FP8 E4M3 stored as uint8.
- Should be in ColumnMajor format (pre-transposed).
- scale_a: Scale factors for A, float32. Size from fp8_fp8_get_scale_sizes().
- scale_b: Scale factors for B, float32. Size from fp8_fp8_get_scale_sizes().
- out: Optional output array (M x N), uint8. If provided, result is
- written to this array instead of allocating a new one.
-
- Returns:
- The result GPUArray (M x N), FP8 E4M3 stored as uint8.
-
- Raises:
- ValueError: If arrays are not 2D, dtypes are wrong, or dimensions don't match.
- RuntimeError: If FP8 SM120 is not available.
-
- Example:
- >>> import pygpukit as gk
- >>> from pygpukit.ops import fp8_fp8_get_scale_sizes, matmul_fp8_fp8_blockwise_sm120
- >>> M, N, K = 256, 256, 256
- >>> sfa_size, sfb_size = fp8_fp8_get_scale_sizes(M, N, K)
- >>> scale_A = gk.from_numpy(np.ones(sfa_size, dtype=np.float32))
- >>> scale_B = gk.from_numpy(np.ones(sfb_size, dtype=np.float32))
- >>> C = matmul_fp8_fp8_blockwise_sm120(A_fp8, B_fp8, scale_A, scale_B)
- """
- from pygpukit.core.dtypes import float32, uint8
-
- if a.ndim != 2:
- raise ValueError(f"matmul_fp8_fp8_blockwise_sm120 requires 2D arrays, got {a.ndim}D for A")
- if b.ndim != 2:
- raise ValueError(f"matmul_fp8_fp8_blockwise_sm120 requires 2D arrays, got {b.ndim}D for B")
-
- if a.shape[1] != b.shape[0]:
- raise ValueError(
- f"matmul_fp8_fp8_blockwise_sm120 dimension mismatch: {a.shape} @ {b.shape} "
- f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)"
- )
-
- if a.dtype != uint8 or b.dtype != uint8:
- raise ValueError("matmul_fp8_fp8_blockwise_sm120 requires uint8 inputs (FP8)")
-
- if scale_a.dtype != float32 or scale_b.dtype != float32:
- raise ValueError("matmul_fp8_fp8_blockwise_sm120 requires float32 scale factors")
-
- if not fp8_fp8_sm120_available():
- raise RuntimeError("FP8 blockwise SM120 GEMM is not available. Requires SM120+.")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _matmul_fp8_fp8_blockwise_sm120_native(a, b, scale_a, scale_b, out=out)
- else:
- raise RuntimeError("FP8 blockwise SM120 GEMM requires native backend")
-
-
-def _matmul_fp8_fp8_blockwise_sm120_native(
- a: GPUArray,
- b: GPUArray,
- scale_a: GPUArray,
- scale_b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Native C++ implementation of blockwise FP8 I/O GEMM for SM120."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- # Get native arrays
- a_native = a._get_native()
- b_native = b._get_native()
- scale_a_native = scale_a._get_native()
- scale_b_native = scale_b._get_native()
-
- # Allocate output if needed
- if out is None:
- M, K = a.shape
- N = b.shape[1]
- out_native = native.empty([M, N], native.DataType.UInt8)
- out = GPUArray._wrap_native(out_native)
- else:
- out_native = out._get_native()
-
- # Call blockwise FP8 GEMM
- native.gemm_fp8_fp8_blockwise_sm120(
- a_native, b_native, out_native, scale_a_native, scale_b_native
- )
-
- return out
-
-
-# Alias for standardized naming
-gemm_fp8_fp8_blockwise_sm120 = matmul_fp8_fp8_blockwise_sm120
-
-
-def matmul_fp8_sm100(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """FP8 matrix multiplication for SM100 (Blackwell datacenter).
-
- This function takes FP32 inputs, internally quantizes them to FP8,
- performs the GEMM using CUTLASS FP8 kernels with BF16 accumulation,
- and returns the result as FP32.
-
- This may work on SM120 (Blackwell GeForce) as a fallback since both
- are Blackwell architecture.
-
- Args:
- a: First input array (M x K), FP32.
- b: Second input array (K x N), FP32.
- out: Optional output array (M x N), FP32. If provided, result is
- written to this array instead of allocating a new one.
-
- Returns:
- The result GPUArray (M x N), FP32.
-
- Raises:
- ValueError: If arrays are not 2D, not FP32, or dimensions don't match.
- RuntimeError: If FP8 SM100 GEMM is not available or kernel fails.
-
- Example:
- >>> import pygpukit as gk
- >>> A = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1)
- >>> B = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1)
- >>> C = gk.ops.matmul_fp8_sm100(A, B)
- """
- from pygpukit.core.dtypes import float32
-
- if a.ndim != 2:
- raise ValueError(f"matmul_fp8_sm100 requires 2D arrays, got {a.ndim}D for first argument")
- if b.ndim != 2:
- raise ValueError(f"matmul_fp8_sm100 requires 2D arrays, got {b.ndim}D for second argument")
-
- if a.shape[1] != b.shape[0]:
- raise ValueError(
- f"matmul_fp8_sm100 dimension mismatch: {a.shape} @ {b.shape} "
- f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)"
- )
-
- if a.dtype != float32 or b.dtype != float32:
- raise ValueError("matmul_fp8_sm100 requires float32 inputs")
-
- if not fp8_sm100_available():
- raise RuntimeError(
- "FP8 SM100 GEMM is not available. Requires SM100+ GPU and CUTLASS SM100 support."
- )
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _matmul_fp8_sm100_native(a, b, out=out)
- else:
- raise RuntimeError("FP8 SM100 GEMM requires native backend")
-
-
-def _matmul_fp8_sm100_native(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Native C++ implementation of FP8 GEMM for SM100."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- # Get native arrays
- a_native = a._get_native()
- b_native = b._get_native()
-
- # Allocate output if needed
- if out is None:
- M, K = a.shape
- N = b.shape[1]
- out_native = native.empty([M, N], native.DataType.Float32)
- out = GPUArray._wrap_native(out_native)
- else:
- out_native = out._get_native()
-
- # Call FP8 GEMM (use new standardized name)
- native.gemm_fp8_f32_sm100(a_native, b_native, out_native)
-
- return out
-
-
-# Alias for standardized naming
-gemm_fp8_f32_sm100 = matmul_fp8_sm100
-
-
-def matmul_fp8_sm120(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """FP8 matrix multiplication for SM120 (Blackwell GeForce).
-
- This function takes FP32 inputs, internally quantizes them to FP8,
- performs the GEMM using CUTLASS FP8 kernels with BF16 accumulation,
- and returns the result as FP32.
-
- Args:
- a: First input array (M x K), FP32.
- b: Second input array (K x N), FP32.
- out: Optional output array (M x N), FP32. If provided, result is
- written to this array instead of allocating a new one.
-
- Returns:
- The result GPUArray (M x N), FP32.
-
- Raises:
- ValueError: If arrays are not 2D, not FP32, or dimensions don't match.
- RuntimeError: If FP8 SM120 GEMM is not available or kernel fails.
-
- Example:
- >>> import pygpukit as gk
- >>> A = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1)
- >>> B = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1)
- >>> C = gk.ops.matmul_fp8_sm120(A, B)
- """
- from pygpukit.core.dtypes import float32
-
- if a.ndim != 2:
- raise ValueError(f"matmul_fp8_sm120 requires 2D arrays, got {a.ndim}D for first argument")
- if b.ndim != 2:
- raise ValueError(f"matmul_fp8_sm120 requires 2D arrays, got {b.ndim}D for second argument")
-
- if a.shape[1] != b.shape[0]:
- raise ValueError(
- f"matmul_fp8_sm120 dimension mismatch: {a.shape} @ {b.shape} "
- f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)"
- )
-
- if a.dtype != float32 or b.dtype != float32:
- raise ValueError("matmul_fp8_sm120 requires float32 inputs")
-
- if not fp8_sm120_available():
- raise RuntimeError(
- "FP8 SM120 GEMM is not available. Requires SM120+ GPU and CUTLASS SM120 support."
- )
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _matmul_fp8_sm120_native(a, b, out=out)
- else:
- raise RuntimeError("FP8 SM120 GEMM requires native backend")
-
-
-def _matmul_fp8_sm120_native(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Native C++ implementation of FP8 GEMM for SM120."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- # Get native arrays
- a_native = a._get_native()
- b_native = b._get_native()
-
- # Allocate output if needed
- if out is None:
- M, K = a.shape
- N = b.shape[1]
- out_native = native.empty([M, N], native.DataType.Float32)
- out = GPUArray._wrap_native(out_native)
- else:
- out_native = out._get_native()
-
- # Call FP8 GEMM (use new standardized name)
- native.gemm_fp8_f32_sm120(a_native, b_native, out_native)
-
- return out
-
-
-# Alias for standardized naming
-gemm_fp8_f32_sm120 = matmul_fp8_sm120
-
-
-def matmul_fp8_sm90(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """FP8 matrix multiplication for SM90 (Hopper).
-
- This function takes FP32 inputs, internally quantizes them to FP8 with
- per-tensor scaling, performs the GEMM using CUTLASS FP8 kernels,
- and returns the result as FP32.
-
- Args:
- a: First input array (M x K), FP32.
- b: Second input array (K x N), FP32.
- out: Optional output array (M x N), FP32. If provided, result is
- written to this array instead of allocating a new one.
-
- Returns:
- The result GPUArray (M x N), FP32.
-
- Raises:
- ValueError: If arrays are not 2D, not FP32, or dimensions don't match.
- RuntimeError: If FP8 SM90 GEMM is not available or kernel fails.
-
- Example:
- >>> import pygpukit as gk
- >>> A = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1)
- >>> B = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1)
- >>> C = gk.ops.matmul_fp8_sm90(A, B)
- """
- from pygpukit.core.dtypes import float32
-
- if a.ndim != 2:
- raise ValueError(f"matmul_fp8_sm90 requires 2D arrays, got {a.ndim}D for first argument")
- if b.ndim != 2:
- raise ValueError(f"matmul_fp8_sm90 requires 2D arrays, got {b.ndim}D for second argument")
-
- if a.shape[1] != b.shape[0]:
- raise ValueError(
- f"matmul_fp8_sm90 dimension mismatch: {a.shape} @ {b.shape} "
- f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)"
- )
-
- if a.dtype != float32 or b.dtype != float32:
- raise ValueError("matmul_fp8_sm90 requires float32 inputs")
-
- if not fp8_sm90_available():
- raise RuntimeError(
- "FP8 SM90 GEMM is not available. Requires SM90+ GPU and CUTLASS SM90 support."
- )
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _matmul_fp8_sm90_native(a, b, out=out)
- else:
- raise RuntimeError("FP8 SM90 GEMM requires native backend")
-
-
-def _matmul_fp8_sm90_native(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Native C++ implementation of FP8 GEMM for SM90."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- # Get native arrays
- a_native = a._get_native()
- b_native = b._get_native()
-
- # Allocate output if needed
- if out is None:
- M, K = a.shape
- N = b.shape[1]
- out_native = native.empty([M, N], native.DataType.Float32)
- out = GPUArray._wrap_native(out_native)
- else:
- out_native = out._get_native()
-
- # Call FP8 GEMM (use new standardized name)
- native.gemm_fp8_f32_sm90(a_native, b_native, out_native)
-
- return out
-
-
-# Alias for standardized naming
-gemm_fp8_f32_sm90 = matmul_fp8_sm90
-
-
-def nvf4_bf16_sm120_available() -> bool:
- """Check if NVF4 (4-bit) BF16 GEMM is available on SM120 (Blackwell GeForce).
-
- This variant uses NVF4 (4-bit float) for 2x memory bandwidth compared to FP8,
- making it ideal for memory-bound LLM inference workloads.
-
- Returns:
- True if NVF4 BF16 SM120 GEMM is available, False otherwise.
- """
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- # Use new standardized name
- return native.gemm_nvf4_bf16_sm120_available()
- else:
- return False
-
-
-# Alias for standardized naming
-gemm_nvf4_bf16_sm120_available = nvf4_bf16_sm120_available
-
-
-def matmul_nvf4_bf16_sm120(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """NVF4 (4-bit) GEMM with BF16 input/output for SM120 (Blackwell GeForce).
-
- This variant uses NVF4 (float_e2m1_t, 4-bit) for the internal computation,
- providing 2x memory bandwidth compared to FP8. Ideal for memory-bound
- LLM inference workloads.
-
- Data flow: BF16 input -> NVF4 quantize with block scaling -> GEMM -> BF16 output
-
- Args:
- a: First input array (M x K), BF16.
- b: Second input array (K x N), BF16.
- out: Optional output array (M x N), BF16.
-
- Returns:
- The result GPUArray (M x N), BF16.
-
- Raises:
- ValueError: If arrays are not 2D, not BF16, or dimensions don't match.
- RuntimeError: If NVF4 BF16 SM120 GEMM is not available.
- """
- from pygpukit.core.dtypes import bfloat16
-
- if a.ndim != 2:
- raise ValueError(f"matmul_nvf4_bf16_sm120 requires 2D arrays, got {a.ndim}D")
- if b.ndim != 2:
- raise ValueError(f"matmul_nvf4_bf16_sm120 requires 2D arrays, got {b.ndim}D")
-
- if a.shape[1] != b.shape[0]:
- raise ValueError(f"matmul_nvf4_bf16_sm120 dimension mismatch: {a.shape} @ {b.shape}")
-
- if a.dtype != bfloat16 or b.dtype != bfloat16:
- raise ValueError("matmul_nvf4_bf16_sm120 requires bfloat16 inputs")
-
- if not nvf4_bf16_sm120_available():
- raise RuntimeError("NVF4 BF16 SM120 GEMM is not available. Requires SM120+ GPU.")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _matmul_nvf4_bf16_sm120_native(a, b, out=out)
- else:
- raise RuntimeError("NVF4 BF16 SM120 GEMM requires native backend")
-
-
-def _matmul_nvf4_bf16_sm120_native(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Native C++ implementation of NVF4 BF16 GEMM for SM120."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- # Get native arrays
- a_native = a._get_native()
- b_native = b._get_native()
-
- # Allocate output if needed
- if out is None:
- M, K = a.shape
- N = b.shape[1]
- out_native = native.empty([M, N], native.DataType.BFloat16)
- out = GPUArray._wrap_native(out_native)
- else:
- out_native = out._get_native()
-
- # Call NVF4 BF16 GEMM
- native.gemm_nvf4_bf16_sm120(a_native, b_native, out_native)
-
- return out
-
-
-# Alias for standardized naming
-gemm_nvf4_bf16_sm120 = matmul_nvf4_bf16_sm120
-
-
-# ============================================================================
-# GEMV Operations (M=1 special case)
-# ============================================================================
-
-
-def gemv_nvf4_available() -> bool:
- """Check if NVF4 GEMV is available (SM120+).
-
- Returns:
- True if NVF4 GEMV is available on current GPU.
- """
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- # Use new standardized name
- return native.gemv_nvf4_bf16_sm120_available()
- else:
- return False
-
-
-# Alias for standardized naming
-gemv_nvf4_bf16_sm120_available = gemv_nvf4_available
-
-
-def nvf4_get_sizes(K: int, N: int) -> tuple[int, int]:
- """Get buffer sizes for NVF4-quantized weights.
-
- Args:
- K: Inner dimension (input features).
- N: Output dimension (output features).
-
- Returns:
- Tuple of (data_size, scale_size) in bytes.
- - data_size: Size for packed NVF4 weights [K/2, N]
- - scale_size: Size for UE4M3 scale factors [K/32, N]
-
- Note:
- NVF4 provides 4x compression vs BF16:
- - BF16 weight size: K * N * 2 bytes
- - NVF4 total size: K/2 * N + K/32 * N bytes
- """
- data_size = (K // 2) * N
- scale_size = ((K + 31) // 32) * N
- return data_size, scale_size
-
-
-# Alias for standardized naming
-gemv_nvf4_get_sizes = nvf4_get_sizes
-
-
-def quantize_bf16_to_nvf4(
- input: GPUArray,
- out_data: GPUArray,
- out_scale: GPUArray,
-) -> None:
- """Quantize BF16 weights to NVF4 format with block scaling.
-
- This quantizes BF16 weights to 4-bit NVF4 format with UE4M3 scale factors.
- Each 32-element block shares one scale factor.
-
- Args:
- input: BF16 weight matrix [K, N].
- out_data: Pre-allocated buffer for packed NVF4 data [K/2, N] (uint8).
- out_scale: Pre-allocated buffer for scale factors [K/32, N] (uint8).
-
- Raises:
- ValueError: If input is not 2D BF16, or buffers have wrong size.
- RuntimeError: If NVF4 is not available.
-
- Note:
- NVF4 values: {0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0} and negatives.
- Block size: 32 elements per scale factor.
- """
- from pygpukit.core.dtypes import bfloat16
-
- if input.ndim != 2:
- raise ValueError(f"quantize_bf16_to_nvf4 requires 2D input, got {input.ndim}D")
-
- if input.dtype != bfloat16:
- raise ValueError(f"quantize_bf16_to_nvf4 requires bfloat16 input, got {input.dtype}")
-
- if not gemv_nvf4_available():
- raise RuntimeError("NVF4 quantization not available. Requires SM120+ GPU.")
-
- K, N = input.shape
- expected_data_size, expected_scale_size = nvf4_get_sizes(K, N)
-
- # Validate buffer sizes (count elements)
- actual_data_size = (
- out_data.shape[0] * out_data.shape[1] if out_data.ndim == 2 else out_data.size
- )
- actual_scale_size = (
- out_scale.shape[0] * out_scale.shape[1] if out_scale.ndim == 2 else out_scale.size
- )
-
- if actual_data_size < expected_data_size:
- raise ValueError(f"out_data buffer too small: {actual_data_size} < {expected_data_size}")
- if actual_scale_size < expected_scale_size:
- raise ValueError(f"out_scale buffer too small: {actual_scale_size} < {expected_scale_size}")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- input_native = input._get_native()
- data_native = out_data._get_native()
- scale_native = out_scale._get_native()
- native.quantize_bf16_to_nvf4(input_native, data_native, scale_native)
-
-
-def gemv_nvf4_bf16(
- a: GPUArray,
- b_data: GPUArray,
- b_scale: GPUArray,
- *,
- out: GPUArray | None = None,
- alpha: float = 1.0,
-) -> GPUArray:
- """NVF4 GEMV: C[N] = alpha * A[K] @ B[K,N] (NVF4 quantized).
-
- This performs matrix-vector multiplication where the weight matrix B
- is pre-quantized to NVF4 format with block scaling.
-
- Args:
- a: Input vector [K], BF16.
- b_data: Packed NVF4 weight data [K/2, N], uint8.
- b_scale: UE4M3 scale factors [K/32, N], uint8.
- out: Optional output vector [N], BF16.
- alpha: Scaling factor (default 1.0).
-
- Returns:
- Output vector [N], BF16.
-
- Raises:
- ValueError: If shapes or dtypes don't match.
- RuntimeError: If NVF4 GEMV is not available.
-
- Note:
- For LLM inference decode path (M=1), NVF4 provides 4x bandwidth
- reduction vs BF16, which is critical for memory-bound workloads.
- """
- from pygpukit.core.dtypes import bfloat16
-
- if a.ndim != 1:
- raise ValueError(f"gemv_nvf4_bf16 requires 1D input vector, got {a.ndim}D")
-
- if a.dtype != bfloat16:
- raise ValueError(f"gemv_nvf4_bf16 requires bfloat16 input, got {a.dtype}")
-
- if not gemv_nvf4_available():
- raise RuntimeError("NVF4 GEMV not available. Requires SM120+ GPU.")
-
- # Infer N from b_data shape: [K/2, N]
- if b_data.ndim == 2:
- N = b_data.shape[1]
- else:
- raise ValueError(f"b_data must be 2D [K/2, N], got {b_data.ndim}D")
-
- # Validate output
- if out is not None:
- if out.shape != (N,):
- raise ValueError(f"out shape {out.shape} does not match expected ({N},)")
- if out.dtype != bfloat16:
- raise ValueError(f"out dtype {out.dtype} must be bfloat16")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- a_native = a._get_native()
- data_native = b_data._get_native()
- scale_native = b_scale._get_native()
-
- if out is None:
- out_native = native.empty([N], native.DataType.BFloat16)
- out = GPUArray._wrap_native(out_native)
- else:
- out_native = out._get_native()
-
- # Use new standardized name
- native.gemv_nvf4_bf16_sm120(a_native, data_native, scale_native, out_native, alpha)
-
- return out
- else:
- raise RuntimeError("NVF4 GEMV requires native backend")
-
-
-# Alias for standardized naming
-gemv_nvf4_bf16_sm120 = gemv_nvf4_bf16
-
-
-def gemv_bf16(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """BF16 GEMV: C[N] = A[K] @ B[N,K]^T.
-
- Optimized BF16 matrix-vector multiplication with B[N,K] layout.
- Each row of B contains the weights for one output element.
-
- Args:
- a: Input vector [K], BF16.
- b: Weight matrix [N, K], BF16 (row-major, each row = one output).
- out: Optional output vector [N], BF16.
-
- Returns:
- Output vector [N], BF16.
-
- Raises:
- ValueError: If shapes or dtypes don't match.
-
- Note:
- This function uses the optimized B[N,K] layout for better memory
- coalescing. If you have weights in [K,N] format, transpose them first.
- """
- from pygpukit.core.dtypes import bfloat16
-
- if a.ndim != 1:
- raise ValueError(f"gemv_bf16 requires 1D input vector, got {a.ndim}D")
-
- if b.ndim != 2:
- raise ValueError(f"gemv_bf16 requires 2D weight matrix, got {b.ndim}D")
-
- if a.dtype != bfloat16 or b.dtype != bfloat16:
- raise ValueError("gemv_bf16 requires bfloat16 inputs")
-
- K = a.shape[0]
- N = b.shape[0] # N is first dim in [N, K] layout
-
- if b.shape[1] != K:
- raise ValueError(f"gemv_bf16 dimension mismatch: A[{K}] vs B[{N}, {b.shape[1]}]")
-
- # Validate output
- if out is not None:
- if out.shape != (N,):
- raise ValueError(f"out shape {out.shape} does not match expected ({N},)")
- if out.dtype != bfloat16:
- raise ValueError(f"out dtype {out.dtype} must be bfloat16")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- a_native = a._get_native()
- b_native = b._get_native()
-
- if out is None:
- out_native = native.empty([N], native.DataType.BFloat16)
- out = GPUArray._wrap_native(out_native)
- else:
- out_native = out._get_native()
-
- # Use optimized kernel with B[N,K] layout (new standardized name)
- native.gemv_bf16_bf16_sm120(a_native, b_native, out_native)
-
- return out
- else:
- # CPU fallback: B[N,K] @ A[K] = C[N] (B @ A^T transposed)
- a_np: np.ndarray[np.floating] = a.to_numpy().astype(np.float32)
- b_np: np.ndarray[np.floating] = b.to_numpy().astype(np.float32)
- result: np.ndarray[np.floating] = b_np @ a_np # [N,K] @ [K] = [N]
- if out is not None:
- result = result + out.to_numpy().astype(np.float32)
- return from_numpy(result.astype(np.float16).view(np.uint16).astype(np.uint16))
-
-
-# Alias for standardized naming
-gemv_bf16_bf16_sm120 = gemv_bf16
-
-
-# Flag to track if FP8 LUT has been initialized
-_FP8_LUT_INITIALIZED = False
-
-
-def fp8_init_lut() -> None:
- """Initialize FP8 E4M3 lookup table for dequantization.
-
- Note: LUT is defined as __device__ __constant__ in C++ and initialized
- at compile time, so this function is a no-op. Kept for API compatibility.
- """
- global _FP8_LUT_INITIALIZED
- if _FP8_LUT_INITIALIZED:
- return
- # LUT is already initialized in constant memory at compile time
- _FP8_LUT_INITIALIZED = True
-
-
-# Flag to track if W8A16 GEMM LUT has been initialized
-_W8A16_GEMM_LUT_INITIALIZED = False
-
-
-def w8a16_gemm_init_lut() -> None:
- """Initialize FP8->F32 LUT for W8A16 GEMM.
-
- This uses runtime initialization to avoid symbol conflicts with the GEMV LUT.
- Must be called before using w8a16_gemm_sm120.
- """
- global _W8A16_GEMM_LUT_INITIALIZED
- if _W8A16_GEMM_LUT_INITIALIZED:
- return
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- # Use new standardized name
- native.gemm_w8a16_init_lut()
- _W8A16_GEMM_LUT_INITIALIZED = True
-
-
-# Alias for standardized naming
-gemm_w8a16_init_lut = w8a16_gemm_init_lut
-
-
-def gemv_fp8_bf16(
- a: GPUArray,
- b_nk: GPUArray,
- b_scale: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Optimized FP8 GEMV: C[N] = A[K] @ B[N,K]^T.
-
- W8A16 GEMV: FP8 weights with BF16 activation and output.
- Uses warp-level reduction, shared memory, and vectorized loads.
-
- Args:
- a: Activation vector [K], BF16.
- b_nk: FP8 E4M3 weight matrix [N, K], stored as uint8.
- b_scale: Block-wise scale factors [N/128, K/128], BF16.
- out: Optional output vector [N], BF16.
-
- Returns:
- Output vector [N], BF16.
-
- Note:
- Weight layout is [N, K] (row = output dimension).
- Use original weight tensor directly (no transpose needed).
- """
- from pygpukit.core.dtypes import bfloat16, uint8
-
- if a.ndim != 1:
- raise ValueError(f"gemv_fp8_bf16 requires 1D input vector, got {a.ndim}D")
-
- if b_nk.ndim != 2:
- raise ValueError(f"gemv_fp8_bf16 requires 2D weight matrix, got {b_nk.ndim}D")
-
- if a.dtype != bfloat16:
- raise ValueError(f"gemv_fp8_bf16 requires bfloat16 activation, got {a.dtype}")
-
- if b_nk.dtype != uint8:
- raise ValueError(f"gemv_fp8_bf16 requires uint8 (FP8) weights, got {b_nk.dtype}")
-
- if b_scale.dtype != bfloat16:
- raise ValueError(f"gemv_fp8_bf16 requires bfloat16 scale, got {b_scale.dtype}")
-
- K = a.shape[0]
- N = b_nk.shape[0] # [N, K] layout
-
- if b_nk.shape[1] != K:
- raise ValueError(f"gemv_fp8_bf16 dimension mismatch: A[{K}] vs B[{N}, {b_nk.shape[1]}]")
-
- # Validate output
- if out is not None:
- if out.shape != (N,):
- raise ValueError(f"out shape {out.shape} does not match expected ({N},)")
- if out.dtype != bfloat16:
- raise ValueError(f"out dtype {out.dtype} must be bfloat16")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- a_native = a._get_native()
- b_nk_native = b_nk._get_native()
- b_scale_native = b_scale._get_native()
-
- if out is None:
- out_native = native.empty([N], native.DataType.BFloat16)
- out = GPUArray._wrap_native(out_native)
- else:
- out_native = out._get_native()
-
- # Use new standardized name
- native.gemv_fp8_bf16_sm120(a_native, b_nk_native, b_scale_native, out_native)
-
- return out
- else:
- raise NotImplementedError("FP8 GEMV requires native GPU backend")
-
-
-# Alias for standardized naming
-gemv_fp8_bf16_sm120 = gemv_fp8_bf16
-
-
-def gemv_fp8_bf16_batched(
- a: GPUArray,
- b_nk: GPUArray,
- b_scale: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Optimized batched FP8 GEMV: C[M,N] = A[M,K] @ B[N,K]^T.
-
- W8A16 GEMM for M>1: FP8 weights with BF16 activation and output.
- Uses warp-level reduction, shared memory, and vectorized loads.
-
- Args:
- a: Activation matrix [M, K], BF16.
- b_nk: FP8 E4M3 weight matrix [N, K], stored as uint8.
- b_scale: Block-wise scale factors [N/128, K/128], BF16.
- out: Optional output matrix [M, N], BF16.
-
- Returns:
- Output matrix [M, N], BF16.
-
- Note:
- Weight layout is [N, K] (row = output dimension).
- Use original weight tensor directly (no transpose needed).
- """
- from pygpukit.core.dtypes import bfloat16, uint8
-
- if a.ndim != 2:
- raise ValueError(f"gemv_fp8_bf16_batched requires 2D input matrix, got {a.ndim}D")
-
- if b_nk.ndim != 2:
- raise ValueError(f"gemv_fp8_bf16_batched requires 2D weight matrix, got {b_nk.ndim}D")
-
- if a.dtype != bfloat16:
- raise ValueError(f"gemv_fp8_bf16_batched requires bfloat16 activation, got {a.dtype}")
-
- if b_nk.dtype != uint8:
- raise ValueError(f"gemv_fp8_bf16_batched requires uint8 (FP8) weights, got {b_nk.dtype}")
-
- if b_scale.dtype != bfloat16:
- raise ValueError(f"gemv_fp8_bf16_batched requires bfloat16 scale, got {b_scale.dtype}")
-
- M = a.shape[0]
- K = a.shape[1]
- N = b_nk.shape[0] # [N, K] layout
-
- if b_nk.shape[1] != K:
- raise ValueError(
- f"gemv_fp8_bf16_batched dimension mismatch: A[{M},{K}] vs B[{N},{b_nk.shape[1]}]"
- )
-
- # Validate output
- if out is not None:
- if out.shape != (M, N):
- raise ValueError(f"out shape {out.shape} does not match expected ({M}, {N})")
- if out.dtype != bfloat16:
- raise ValueError(f"out dtype {out.dtype} must be bfloat16")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- a_native = a._get_native()
- b_nk_native = b_nk._get_native()
- b_scale_native = b_scale._get_native()
-
- if out is None:
- out_native = native.empty([M, N], native.DataType.BFloat16)
- out = GPUArray._wrap_native(out_native)
- else:
- out_native = out._get_native()
-
- # Use new standardized name
- native.gemv_fp8_bf16_batched_sm120(a_native, b_nk_native, b_scale_native, out_native)
-
- return out
- else:
- raise NotImplementedError("FP8 batched GEMV requires native GPU backend")
-
-
-# Alias for standardized naming
-gemv_fp8_bf16_batched_sm120 = gemv_fp8_bf16_batched
-
-
-def w8a16_gemm_sm120(
- a: GPUArray,
- b_fp8: GPUArray,
- b_scale: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """W8A16 GEMM for SM120: C[M,N] = A[M,K] @ dequant(B_fp8[K,N]).
-
- FP8 weight x BF16 activation -> BF16 output.
- Uses TensorCore GEMM with online FP8 dequantization.
- More efficient than batched GEMV for M > 1.
-
- Args:
- a: Activation matrix [M, K], BF16.
- b_fp8: FP8 E4M3 weight matrix [K, N], stored as uint8.
- b_scale: Block-wise scale factors [K/128, N/128], BF16.
- out: Optional output matrix [M, N], BF16.
-
- Returns:
- Output matrix [M, N], BF16.
- """
- from pygpukit.core.dtypes import bfloat16, uint8
-
- if a.ndim != 2:
- raise ValueError(f"w8a16_gemm_sm120 requires 2D input matrix, got {a.ndim}D")
-
- if b_fp8.ndim != 2:
- raise ValueError(f"w8a16_gemm_sm120 requires 2D weight matrix, got {b_fp8.ndim}D")
-
- if a.dtype != bfloat16:
- raise ValueError(f"w8a16_gemm_sm120 requires bfloat16 activation, got {a.dtype}")
-
- if b_fp8.dtype != uint8:
- raise ValueError(f"w8a16_gemm_sm120 requires uint8 (FP8) weights, got {b_fp8.dtype}")
-
- if b_scale.dtype != bfloat16:
- raise ValueError(f"w8a16_gemm_sm120 requires bfloat16 scale, got {b_scale.dtype}")
-
- M = a.shape[0]
- K = a.shape[1]
- if b_fp8.shape[0] != K:
- raise ValueError(
- f"w8a16_gemm_sm120 dimension mismatch: A[{M},{K}] vs B[{b_fp8.shape[0]}, {b_fp8.shape[1]}]"
- )
-
- N = b_fp8.shape[1]
-
- # Validate output
- if out is not None:
- if out.shape != (M, N):
- raise ValueError(f"out shape {out.shape} does not match expected ({M}, {N})")
- if out.dtype != bfloat16:
- raise ValueError(f"out dtype {out.dtype} must be bfloat16")
-
- # Initialize W8A16 GEMM LUT (runtime initialization to avoid symbol conflicts)
- w8a16_gemm_init_lut()
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- a_native = a._get_native()
- b_fp8_native = b_fp8._get_native()
- b_scale_native = b_scale._get_native()
-
- if out is None:
- out_native = native.empty([M, N], native.DataType.BFloat16)
- out = GPUArray._wrap_native(out_native)
- else:
- out_native = out._get_native()
-
- # Use new standardized name
- native.gemm_w8a16_bf16_sm120(a_native, b_fp8_native, b_scale_native, out_native)
-
- return out
- else:
- raise NotImplementedError("W8A16 GEMM requires native GPU backend with SM120")
-
-
-# Alias for standardized naming
-gemm_w8a16_bf16_sm120 = w8a16_gemm_sm120
-
-
-# Track if grouped GEMM LUT is initialized
-_grouped_gemm_lut_initialized = False
-
-
-def grouped_gemm_init_lut() -> None:
- """Initialize FP8->BF16 LUT for grouped GEMM.
-
- This must be called once before using grouped_gemm_fp8_bf16.
- """
- global _grouped_gemm_lut_initialized
- if _grouped_gemm_lut_initialized:
- return
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- native.grouped_gemm_init_lut()
- _grouped_gemm_lut_initialized = True
- else:
- raise NotImplementedError("Grouped GEMM requires native GPU backend")
-
-
-def grouped_gemm_fp8_bf16(
- a: GPUArray,
- b_stacked: GPUArray,
- b_scale: GPUArray,
- row_expert_ids: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Grouped GEMM for MoE: C = A @ B_stacked with per-row expert IDs.
-
- Each row has an associated expert ID, and the kernel dispatches to the
- correct expert's weights for each row.
-
- Args:
- a: Input tokens [M, K], BF16.
- b_stacked: Stacked expert weights [num_experts, N, K], FP8 (uint8).
- b_scale: Block-wise scales [num_experts, N/128, K/128], BF16.
- row_expert_ids: Expert ID for each row [M], int32.
- out: Optional output tensor [M, N], BF16.
-
- Returns:
- Output tensor [M, N], BF16.
- """
- from pygpukit.core.dtypes import bfloat16, int32, uint8
-
- if a.ndim != 2:
- raise ValueError(f"grouped_gemm_fp8_bf16 requires 2D input, got {a.ndim}D")
-
- if b_stacked.ndim != 3:
- raise ValueError(f"grouped_gemm_fp8_bf16 requires 3D weight, got {b_stacked.ndim}D")
-
- if a.dtype != bfloat16:
- raise ValueError(f"grouped_gemm_fp8_bf16 requires bfloat16 input, got {a.dtype}")
-
- if b_stacked.dtype != uint8:
- raise ValueError(
- f"grouped_gemm_fp8_bf16 requires uint8 (FP8) weights, got {b_stacked.dtype}"
- )
-
- if b_scale.dtype != bfloat16:
- raise ValueError(f"grouped_gemm_fp8_bf16 requires bfloat16 scale, got {b_scale.dtype}")
-
- if row_expert_ids.dtype != int32:
- raise ValueError(
- f"grouped_gemm_fp8_bf16 requires int32 row_expert_ids, got {row_expert_ids.dtype}"
- )
-
- M = a.shape[0]
- K = a.shape[1]
- N = b_stacked.shape[1]
-
- if b_stacked.shape[2] != K:
- raise ValueError(
- f"grouped_gemm_fp8_bf16: K mismatch A[{M},{K}] vs B[...{N},{b_stacked.shape[2]}]"
- )
-
- if row_expert_ids.shape[0] != M:
- raise ValueError(
- f"grouped_gemm_fp8_bf16: row_expert_ids size {row_expert_ids.shape[0]} != M ({M})"
- )
-
- # Validate output
- if out is not None:
- if out.shape != (M, N):
- raise ValueError(f"out shape {out.shape} does not match expected ({M}, {N})")
- if out.dtype != bfloat16:
- raise ValueError(f"out dtype {out.dtype} must be bfloat16")
-
- # Initialize LUT if not already done
- grouped_gemm_init_lut()
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- a_native = a._get_native()
- b_stacked_native = b_stacked._get_native()
- b_scale_native = b_scale._get_native()
- row_expert_ids_native = row_expert_ids._get_native()
-
- if out is None:
- out_native = native.empty([M, N], native.DataType.BFloat16)
- out = GPUArray._wrap_native(out_native)
- else:
- out_native = out._get_native()
-
- # Use new standardized name
- native.grouped_gemm_fp8_bf16_sm120(
- a_native, b_stacked_native, b_scale_native, out_native, row_expert_ids_native
- )
-
- return out
- else:
- raise NotImplementedError("Grouped GEMM requires native GPU backend")
-
-
-# Alias for standardized naming
-grouped_gemm_fp8_bf16_sm120 = grouped_gemm_fp8_bf16
-
-
-def fp8_get_sizes(K: int, N: int) -> tuple[int, int, int]:
- """Get scale tensor dimensions for FP8 block quantization.
-
- Args:
- K: Input dimension.
- N: Output dimension.
-
- Returns:
- (scale_K, scale_N, scale_size_bytes): Scale tensor dimensions
- for 128x128 block quantization.
- """
- scale_k = (K + 127) // 128
- scale_n = (N + 127) // 128
- scale_size = scale_k * scale_n * 2 # BF16 = 2 bytes
- return scale_k, scale_n, scale_size
-
-
-# ============================================================================
-# FP8 Operations
-# ============================================================================
-
-
-def matmul_fp8(
- a: GPUArray,
- b: GPUArray,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """FP8 matrix multiplication with automatic backend selection.
-
- This function takes FP32 inputs, internally quantizes them to FP8,
- performs the GEMM using the best available CUTLASS FP8 kernel,
- and returns the result as FP32.
-
- Backend priority:
- - SM120 (Blackwell GeForce): blockwise scaling (when CUTLASS bug #2902 is fixed)
- - SM90 (Hopper): per-tensor scaling
-
- Args:
- a: First input array (M x K), FP32.
- b: Second input array (K x N), FP32.
- out: Optional output array (M x N), FP32. If provided, result is
- written to this array instead of allocating a new one.
-
- Returns:
- The result GPUArray (M x N), FP32.
-
- Raises:
- ValueError: If arrays are not 2D, not FP32, or dimensions don't match.
- RuntimeError: If no FP8 GEMM backend is available.
-
- Example:
- >>> import pygpukit as gk
- >>> A = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1)
- >>> B = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1)
- >>> C = gk.ops.matmul_fp8(A, B)
- """
- from pygpukit.core.dtypes import float32
-
- if a.ndim != 2:
- raise ValueError(f"matmul_fp8 requires 2D arrays, got {a.ndim}D for first argument")
- if b.ndim != 2:
- raise ValueError(f"matmul_fp8 requires 2D arrays, got {b.ndim}D for second argument")
-
- if a.shape[1] != b.shape[0]:
- raise ValueError(
- f"matmul_fp8 dimension mismatch: {a.shape} @ {b.shape} "
- f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)"
- )
-
- if a.dtype != float32 or b.dtype != float32:
- raise ValueError("matmul_fp8 requires float32 inputs")
-
- if not fp8_available():
- raise RuntimeError("FP8 GEMM is not available. Requires SM90+ GPU and CUTLASS support.")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
-
- # Get native arrays
- a_native = a._get_native()
- b_native = b._get_native()
-
- # Allocate output if needed
- if out is None:
- M, K = a.shape
- N = b.shape[1]
- out_native = native.empty([M, N], native.DataType.Float32)
- out = GPUArray._wrap_native(out_native)
- else:
- out_native = out._get_native()
-
- # Call auto-dispatch FP8 GEMM
- native.gemm_fp8(a_native, b_native, out_native)
-
- return out
- else:
- raise RuntimeError("FP8 GEMM requires native backend")
diff --git a/src/pygpukit/ops/matmul/__init__.py b/src/pygpukit/ops/matmul/__init__.py
new file mode 100644
index 0000000..a1ed162
--- /dev/null
+++ b/src/pygpukit/ops/matmul/__init__.py
@@ -0,0 +1,155 @@
+"""Matrix multiplication operations for GPUArrays.
+
+This module provides various GEMM (General Matrix Multiply) and GEMV
+(General Matrix-Vector) operations optimized for different GPU architectures
+and data types.
+
+Corresponds to native/ops/matmul/.
+"""
+
+from __future__ import annotations
+
+# Availability checks
+from .availability import (
+ fp8_available,
+ fp8_fp8_sm120_available,
+ fp8_sm90_available,
+ fp8_sm100_available,
+ fp8_sm120_available,
+ gemm_fp8_available,
+ gemm_fp8_f32_sm90_available,
+ gemm_fp8_f32_sm100_available,
+ gemm_fp8_f32_sm120_available,
+ gemm_fp8_fp8_sm120_available,
+ gemm_nvf4_bf16_sm120_available,
+ gemv_nvf4_available,
+ gemv_nvf4_bf16_sm120_available,
+ nvf4_bf16_sm120_available,
+)
+
+# FP8 GEMM operations
+from .fp8 import (
+ fp8_fp8_get_scale_sizes,
+ fp8_get_sizes,
+ fp8_init_lut,
+ gemm_fp8_f32_sm90,
+ gemm_fp8_f32_sm100,
+ gemm_fp8_f32_sm120,
+ gemm_fp8_fp8_blockwise_sm120,
+ gemm_fp8_fp8_get_scale_sizes,
+ gemm_fp8_fp8_sm120,
+ matmul_fp8,
+ matmul_fp8_fp8_blockwise_sm120,
+ matmul_fp8_fp8_sm120,
+ matmul_fp8_sm90,
+ matmul_fp8_sm100,
+ matmul_fp8_sm120,
+)
+
+# GEMV operations
+from .gemv import (
+ gemv_bf16,
+ gemv_bf16_bf16_sm120,
+ gemv_fp8_bf16,
+ gemv_fp8_bf16_batched,
+ gemv_fp8_bf16_batched_sm120,
+ gemv_fp8_bf16_sm120,
+)
+
+# Generic matmul operations
+from .generic import (
+ batched_matmul,
+ linear_bias_gelu,
+ matmul,
+ transpose,
+)
+
+# Grouped GEMM for MoE
+from .grouped import (
+ grouped_gemm_fp8_bf16,
+ grouped_gemm_fp8_bf16_sm120,
+ grouped_gemm_init_lut,
+)
+
+# NVF4 (4-bit) operations
+from .nvf4 import (
+ gemm_nvf4_bf16_sm120,
+ gemv_nvf4_bf16,
+ gemv_nvf4_bf16_sm120,
+ gemv_nvf4_get_sizes,
+ matmul_nvf4_bf16_sm120,
+ nvf4_get_sizes,
+ quantize_bf16_to_nvf4,
+)
+
+# W8A16 GEMM operations
+from .w8a16 import (
+ gemm_w8a16_bf16_sm120,
+ gemm_w8a16_init_lut,
+ w8a16_gemm_init_lut,
+ w8a16_gemm_sm120,
+)
+
+__all__ = [
+ # Generic operations
+ "matmul",
+ "batched_matmul",
+ "transpose",
+ "linear_bias_gelu",
+ # Availability checks
+ "fp8_available",
+ "gemm_fp8_available",
+ "fp8_sm90_available",
+ "gemm_fp8_f32_sm90_available",
+ "fp8_sm100_available",
+ "gemm_fp8_f32_sm100_available",
+ "fp8_sm120_available",
+ "gemm_fp8_f32_sm120_available",
+ "fp8_fp8_sm120_available",
+ "gemm_fp8_fp8_sm120_available",
+ "nvf4_bf16_sm120_available",
+ "gemm_nvf4_bf16_sm120_available",
+ "gemv_nvf4_available",
+ "gemv_nvf4_bf16_sm120_available",
+ # FP8 GEMM operations
+ "matmul_fp8",
+ "matmul_fp8_sm90",
+ "matmul_fp8_sm100",
+ "matmul_fp8_sm120",
+ "matmul_fp8_fp8_sm120",
+ "matmul_fp8_fp8_blockwise_sm120",
+ "fp8_fp8_get_scale_sizes",
+ "fp8_get_sizes",
+ "fp8_init_lut",
+ # FP8 aliases
+ "gemm_fp8_f32_sm90",
+ "gemm_fp8_f32_sm100",
+ "gemm_fp8_f32_sm120",
+ "gemm_fp8_fp8_sm120",
+ "gemm_fp8_fp8_blockwise_sm120",
+ "gemm_fp8_fp8_get_scale_sizes",
+ # NVF4 (4-bit) operations
+ "nvf4_get_sizes",
+ "gemv_nvf4_get_sizes",
+ "quantize_bf16_to_nvf4",
+ "matmul_nvf4_bf16_sm120",
+ "gemm_nvf4_bf16_sm120",
+ "gemv_nvf4_bf16",
+ "gemv_nvf4_bf16_sm120",
+ # GEMV operations
+ "gemv_bf16",
+ "gemv_bf16_bf16_sm120",
+ "gemv_fp8_bf16",
+ "gemv_fp8_bf16_sm120",
+ "gemv_fp8_bf16_batched",
+ "gemv_fp8_bf16_batched_sm120",
+ # W8A16 GEMM operations
+ "w8a16_gemm_init_lut",
+ "gemm_w8a16_init_lut",
+ "w8a16_gemm_sm120",
+ "gemm_w8a16_bf16_sm120",
+ # Grouped GEMM (MoE)
+ "grouped_gemm_init_lut",
+ "grouped_gemm_fp8_bf16",
+ "grouped_gemm_fp8_bf16_sm120",
+]
diff --git a/src/pygpukit/ops/matmul/availability.py b/src/pygpukit/ops/matmul/availability.py
new file mode 100644
index 0000000..cb6cfbb
--- /dev/null
+++ b/src/pygpukit/ops/matmul/availability.py
@@ -0,0 +1,128 @@
+"""Availability check functions for GEMM/GEMV operations.
+
+All *_available() functions to check GPU capability.
+"""
+
+from __future__ import annotations
+
+from pygpukit.core.backend import NativeBackend, get_backend
+
+
+def fp8_available() -> bool:
+ """Check if FP8 GEMM is available (any backend)."""
+ backend = get_backend()
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ return (
+ native.gemm_fp8_f32_sm90_available()
+ or native.gemm_fp8_f32_sm100_available()
+ or native.gemm_fp8_f32_sm120_available()
+ )
+ return False
+
+
+gemm_fp8_available = fp8_available
+
+
+def fp8_sm90_available() -> bool:
+ """Check if FP8 GEMM is available on SM90 (Hopper)."""
+ backend = get_backend()
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ return native.gemm_fp8_f32_sm90_available()
+ return False
+
+
+gemm_fp8_f32_sm90_available = fp8_sm90_available
+
+
+def fp8_sm100_available() -> bool:
+ """Check if FP8 GEMM is available on SM100 (Blackwell datacenter)."""
+ backend = get_backend()
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ return native.gemm_fp8_f32_sm100_available()
+ return False
+
+
+gemm_fp8_f32_sm100_available = fp8_sm100_available
+
+
+def fp8_sm120_available() -> bool:
+ """Check if FP8 GEMM is available on SM120 (Blackwell GeForce)."""
+ backend = get_backend()
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ return native.gemm_fp8_f32_sm120_available()
+ return False
+
+
+gemm_fp8_f32_sm120_available = fp8_sm120_available
+
+
+def fp8_fp8_sm120_available() -> bool:
+ """Check if Pure FP8 I/O GEMM is available on SM120."""
+ backend = get_backend()
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ return native.gemm_fp8_fp8_sm120_available()
+ return False
+
+
+gemm_fp8_fp8_sm120_available = fp8_fp8_sm120_available
+
+
+def nvf4_bf16_sm120_available() -> bool:
+ """Check if NVF4 (4-bit) BF16 GEMM is available on SM120."""
+ backend = get_backend()
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ return native.gemm_nvf4_bf16_sm120_available()
+ return False
+
+
+gemm_nvf4_bf16_sm120_available = nvf4_bf16_sm120_available
+
+
+def gemv_nvf4_available() -> bool:
+ """Check if NVF4 GEMV is available (SM120+)."""
+ backend = get_backend()
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ return native.gemv_nvf4_bf16_sm120_available()
+ return False
+
+
+gemv_nvf4_bf16_sm120_available = gemv_nvf4_available
+
+
+__all__ = [
+ "fp8_available",
+ "gemm_fp8_available",
+ "fp8_sm90_available",
+ "gemm_fp8_f32_sm90_available",
+ "fp8_sm100_available",
+ "gemm_fp8_f32_sm100_available",
+ "fp8_sm120_available",
+ "gemm_fp8_f32_sm120_available",
+ "fp8_fp8_sm120_available",
+ "gemm_fp8_fp8_sm120_available",
+ "nvf4_bf16_sm120_available",
+ "gemm_nvf4_bf16_sm120_available",
+ "gemv_nvf4_available",
+ "gemv_nvf4_bf16_sm120_available",
+]
diff --git a/src/pygpukit/ops/matmul/fp8.py b/src/pygpukit/ops/matmul/fp8.py
new file mode 100644
index 0000000..dbf70ea
--- /dev/null
+++ b/src/pygpukit/ops/matmul/fp8.py
@@ -0,0 +1,383 @@
+"""FP8 GEMM operations.
+
+FP8 matrix multiplication for SM90/SM100/SM120.
+"""
+
+from __future__ import annotations
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.backend import NativeBackend, get_backend
+
+from .availability import (
+ fp8_available,
+ fp8_fp8_sm120_available,
+ fp8_sm90_available,
+ fp8_sm100_available,
+ fp8_sm120_available,
+)
+
+
+def matmul_fp8(
+ a: GPUArray,
+ b: GPUArray,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """FP8 matrix multiplication with automatic backend selection.
+
+ Takes FP32 inputs, internally quantizes to FP8, performs GEMM,
+ and returns FP32 result.
+ """
+ from pygpukit.core.dtypes import float32
+
+ if a.ndim != 2:
+ raise ValueError(f"matmul_fp8 requires 2D arrays, got {a.ndim}D for first argument")
+ if b.ndim != 2:
+ raise ValueError(f"matmul_fp8 requires 2D arrays, got {b.ndim}D for second argument")
+
+ if a.shape[1] != b.shape[0]:
+ raise ValueError(
+ f"matmul_fp8 dimension mismatch: {a.shape} @ {b.shape} "
+ f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)"
+ )
+
+ if a.dtype != float32 or b.dtype != float32:
+ raise ValueError("matmul_fp8 requires float32 inputs")
+
+ if not fp8_available():
+ raise RuntimeError("FP8 GEMM is not available. Requires SM90+ GPU and CUTLASS support.")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+ b_native = b._get_native()
+
+ if out is None:
+ M, K = a.shape
+ N = b.shape[1]
+ out_native = native.empty([M, N], native.DataType.Float32)
+ out = GPUArray._wrap_native(out_native)
+ else:
+ out_native = out._get_native()
+
+ native.gemm_fp8(a_native, b_native, out_native)
+ return out
+ else:
+ raise RuntimeError("FP8 GEMM requires native backend")
+
+
+def matmul_fp8_sm90(
+ a: GPUArray,
+ b: GPUArray,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """FP8 matrix multiplication for SM90 (Hopper)."""
+ from pygpukit.core.dtypes import float32
+
+ if a.ndim != 2:
+ raise ValueError(f"matmul_fp8_sm90 requires 2D arrays, got {a.ndim}D for first argument")
+ if b.ndim != 2:
+ raise ValueError(f"matmul_fp8_sm90 requires 2D arrays, got {b.ndim}D for second argument")
+
+ if a.shape[1] != b.shape[0]:
+ raise ValueError(f"matmul_fp8_sm90 dimension mismatch: {a.shape} @ {b.shape}")
+
+ if a.dtype != float32 or b.dtype != float32:
+ raise ValueError("matmul_fp8_sm90 requires float32 inputs")
+
+ if not fp8_sm90_available():
+ raise RuntimeError("FP8 SM90 GEMM is not available.")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+ b_native = b._get_native()
+
+ if out is None:
+ M, K = a.shape
+ N = b.shape[1]
+ out_native = native.empty([M, N], native.DataType.Float32)
+ out = GPUArray._wrap_native(out_native)
+ else:
+ out_native = out._get_native()
+
+ native.gemm_fp8_f32_sm90(a_native, b_native, out_native)
+ return out
+ else:
+ raise RuntimeError("FP8 SM90 GEMM requires native backend")
+
+
+gemm_fp8_f32_sm90 = matmul_fp8_sm90
+
+
+def matmul_fp8_sm100(
+ a: GPUArray,
+ b: GPUArray,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """FP8 matrix multiplication for SM100 (Blackwell datacenter)."""
+ from pygpukit.core.dtypes import float32
+
+ if a.ndim != 2:
+ raise ValueError(f"matmul_fp8_sm100 requires 2D arrays, got {a.ndim}D")
+ if b.ndim != 2:
+ raise ValueError(f"matmul_fp8_sm100 requires 2D arrays, got {b.ndim}D")
+
+ if a.shape[1] != b.shape[0]:
+ raise ValueError(f"matmul_fp8_sm100 dimension mismatch: {a.shape} @ {b.shape}")
+
+ if a.dtype != float32 or b.dtype != float32:
+ raise ValueError("matmul_fp8_sm100 requires float32 inputs")
+
+ if not fp8_sm100_available():
+ raise RuntimeError("FP8 SM100 GEMM is not available.")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+ b_native = b._get_native()
+
+ if out is None:
+ M, K = a.shape
+ N = b.shape[1]
+ out_native = native.empty([M, N], native.DataType.Float32)
+ out = GPUArray._wrap_native(out_native)
+ else:
+ out_native = out._get_native()
+
+ native.gemm_fp8_f32_sm100(a_native, b_native, out_native)
+ return out
+ else:
+ raise RuntimeError("FP8 SM100 GEMM requires native backend")
+
+
+gemm_fp8_f32_sm100 = matmul_fp8_sm100
+
+
+def matmul_fp8_sm120(
+ a: GPUArray,
+ b: GPUArray,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """FP8 matrix multiplication for SM120 (Blackwell GeForce)."""
+ from pygpukit.core.dtypes import float32
+
+ if a.ndim != 2:
+ raise ValueError(f"matmul_fp8_sm120 requires 2D arrays, got {a.ndim}D")
+ if b.ndim != 2:
+ raise ValueError(f"matmul_fp8_sm120 requires 2D arrays, got {b.ndim}D")
+
+ if a.shape[1] != b.shape[0]:
+ raise ValueError(f"matmul_fp8_sm120 dimension mismatch: {a.shape} @ {b.shape}")
+
+ if a.dtype != float32 or b.dtype != float32:
+ raise ValueError("matmul_fp8_sm120 requires float32 inputs")
+
+ if not fp8_sm120_available():
+ raise RuntimeError("FP8 SM120 GEMM is not available.")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+ b_native = b._get_native()
+
+ if out is None:
+ M, K = a.shape
+ N = b.shape[1]
+ out_native = native.empty([M, N], native.DataType.Float32)
+ out = GPUArray._wrap_native(out_native)
+ else:
+ out_native = out._get_native()
+
+ native.gemm_fp8_f32_sm120(a_native, b_native, out_native)
+ return out
+ else:
+ raise RuntimeError("FP8 SM120 GEMM requires native backend")
+
+
+gemm_fp8_f32_sm120 = matmul_fp8_sm120
+
+
+def matmul_fp8_fp8_sm120(
+ a: GPUArray,
+ b: GPUArray,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """Pure FP8 I/O matrix multiplication for SM120 (Blackwell GeForce).
+
+ Takes FP8 E4M3 inputs directly (no conversion from FP32).
+ """
+ from pygpukit.core.dtypes import uint8
+
+ if a.ndim != 2:
+ raise ValueError(f"matmul_fp8_fp8_sm120 requires 2D arrays, got {a.ndim}D")
+ if b.ndim != 2:
+ raise ValueError(f"matmul_fp8_fp8_sm120 requires 2D arrays, got {b.ndim}D")
+
+ if a.shape[1] != b.shape[0]:
+ raise ValueError(f"matmul_fp8_fp8_sm120 dimension mismatch: {a.shape} @ {b.shape}")
+
+ if a.dtype != uint8 or b.dtype != uint8:
+ raise ValueError("matmul_fp8_fp8_sm120 requires uint8 inputs (FP8 E4M3)")
+
+ if not fp8_fp8_sm120_available():
+ raise RuntimeError("Pure FP8 SM120 GEMM is not available.")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+ b_native = b._get_native()
+
+ if out is None:
+ M, K = a.shape
+ N = b.shape[1]
+ out_native = native.empty([M, N], native.DataType.UInt8)
+ out = GPUArray._wrap_native(out_native)
+ else:
+ out_native = out._get_native()
+
+ native.gemm_fp8_fp8_sm120(a_native, b_native, out_native)
+ return out
+ else:
+ raise RuntimeError("Pure FP8 SM120 GEMM requires native backend")
+
+
+gemm_fp8_fp8_sm120 = matmul_fp8_fp8_sm120
+
+
+def fp8_fp8_get_scale_sizes(M: int, N: int, K: int) -> tuple[int, int]:
+ """Get scale factor sizes for FP8 blockwise GEMM."""
+ backend = get_backend()
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ return native.gemm_fp8_fp8_get_scale_sizes(M, N, K)
+ return (0, 0)
+
+
+gemm_fp8_fp8_get_scale_sizes = fp8_fp8_get_scale_sizes
+
+
+def matmul_fp8_fp8_blockwise_sm120(
+ a: GPUArray,
+ b: GPUArray,
+ scale_a: GPUArray,
+ scale_b: GPUArray,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """Blockwise scaled FP8 I/O matrix multiplication for SM120."""
+ from pygpukit.core.dtypes import float32, uint8
+
+ if a.ndim != 2:
+ raise ValueError(f"matmul_fp8_fp8_blockwise_sm120 requires 2D arrays, got {a.ndim}D")
+ if b.ndim != 2:
+ raise ValueError(f"matmul_fp8_fp8_blockwise_sm120 requires 2D arrays, got {b.ndim}D")
+
+ if a.shape[1] != b.shape[0]:
+ raise ValueError(
+ f"matmul_fp8_fp8_blockwise_sm120 dimension mismatch: {a.shape} @ {b.shape}"
+ )
+
+ if a.dtype != uint8 or b.dtype != uint8:
+ raise ValueError("matmul_fp8_fp8_blockwise_sm120 requires uint8 inputs (FP8)")
+
+ if scale_a.dtype != float32 or scale_b.dtype != float32:
+ raise ValueError("matmul_fp8_fp8_blockwise_sm120 requires float32 scale factors")
+
+ if not fp8_fp8_sm120_available():
+ raise RuntimeError("FP8 blockwise SM120 GEMM is not available.")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+ b_native = b._get_native()
+ scale_a_native = scale_a._get_native()
+ scale_b_native = scale_b._get_native()
+
+ if out is None:
+ M, K = a.shape
+ N = b.shape[1]
+ out_native = native.empty([M, N], native.DataType.UInt8)
+ out = GPUArray._wrap_native(out_native)
+ else:
+ out_native = out._get_native()
+
+ native.gemm_fp8_fp8_blockwise_sm120(
+ a_native, b_native, out_native, scale_a_native, scale_b_native
+ )
+ return out
+ else:
+ raise RuntimeError("FP8 blockwise SM120 GEMM requires native backend")
+
+
+gemm_fp8_fp8_blockwise_sm120 = matmul_fp8_fp8_blockwise_sm120
+
+
+def fp8_get_sizes(K: int, N: int) -> tuple[int, int, int]:
+ """Get scale tensor dimensions for FP8 block quantization."""
+ scale_k = (K + 127) // 128
+ scale_n = (N + 127) // 128
+ scale_size = scale_k * scale_n * 2
+ return scale_k, scale_n, scale_size
+
+
+# LUT initialization
+_FP8_LUT_INITIALIZED = False
+
+
+def fp8_init_lut() -> None:
+ """Initialize FP8 E4M3 lookup table for dequantization."""
+ global _FP8_LUT_INITIALIZED
+ if _FP8_LUT_INITIALIZED:
+ return
+ _FP8_LUT_INITIALIZED = True
+
+
+__all__ = [
+ "matmul_fp8",
+ "matmul_fp8_sm90",
+ "matmul_fp8_sm100",
+ "matmul_fp8_sm120",
+ "matmul_fp8_fp8_sm120",
+ "matmul_fp8_fp8_blockwise_sm120",
+ "fp8_fp8_get_scale_sizes",
+ "fp8_get_sizes",
+ "fp8_init_lut",
+ # Aliases
+ "gemm_fp8_f32_sm90",
+ "gemm_fp8_f32_sm100",
+ "gemm_fp8_f32_sm120",
+ "gemm_fp8_fp8_sm120",
+ "gemm_fp8_fp8_blockwise_sm120",
+ "gemm_fp8_fp8_get_scale_sizes",
+]
diff --git a/src/pygpukit/ops/matmul/gemv.py b/src/pygpukit/ops/matmul/gemv.py
new file mode 100644
index 0000000..9ca8df3
--- /dev/null
+++ b/src/pygpukit/ops/matmul/gemv.py
@@ -0,0 +1,205 @@
+"""GEMV (Matrix-Vector) operations.
+
+Optimized GEMV for LLM decode (M=1 case).
+"""
+
+from __future__ import annotations
+
+import numpy as np
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.backend import NativeBackend, get_backend
+from pygpukit.core.factory import from_numpy
+
+
+def gemv_bf16(
+ a: GPUArray,
+ b: GPUArray,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """BF16 GEMV: C[N] = A[K] @ B[N,K]^T.
+
+ Optimized BF16 matrix-vector multiplication with B[N,K] layout.
+ """
+ from pygpukit.core.dtypes import bfloat16
+
+ if a.ndim != 1:
+ raise ValueError(f"gemv_bf16 requires 1D input vector, got {a.ndim}D")
+ if b.ndim != 2:
+ raise ValueError(f"gemv_bf16 requires 2D weight matrix, got {b.ndim}D")
+ if a.dtype != bfloat16 or b.dtype != bfloat16:
+ raise ValueError("gemv_bf16 requires bfloat16 inputs")
+
+ K = a.shape[0]
+ N = b.shape[0]
+
+ if b.shape[1] != K:
+ raise ValueError(f"gemv_bf16 dimension mismatch: A[{K}] vs B[{N}, {b.shape[1]}]")
+
+ if out is not None:
+ if out.shape != (N,):
+ raise ValueError(f"out shape {out.shape} does not match expected ({N},)")
+ if out.dtype != bfloat16:
+ raise ValueError(f"out dtype {out.dtype} must be bfloat16")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+ b_native = b._get_native()
+
+ if out is None:
+ out_native = native.empty([N], native.DataType.BFloat16)
+ out = GPUArray._wrap_native(out_native)
+ else:
+ out_native = out._get_native()
+
+ native.gemv_bf16_bf16_sm120(a_native, b_native, out_native)
+ return out
+ else:
+ a_np: np.ndarray = a.to_numpy().astype(np.float32)
+ b_np: np.ndarray = b.to_numpy().astype(np.float32)
+ result: np.ndarray = b_np @ a_np
+ return from_numpy(result.astype(np.float16).view(np.uint16).astype(np.uint16))
+
+
+gemv_bf16_bf16_sm120 = gemv_bf16
+
+
+def gemv_fp8_bf16(
+ a: GPUArray,
+ b_nk: GPUArray,
+ b_scale: GPUArray,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """Optimized FP8 GEMV: C[N] = A[K] @ B[N,K]^T.
+
+ W8A16 GEMV: FP8 weights with BF16 activation and output.
+ """
+ from pygpukit.core.dtypes import bfloat16, uint8
+
+ if a.ndim != 1:
+ raise ValueError(f"gemv_fp8_bf16 requires 1D input vector, got {a.ndim}D")
+ if b_nk.ndim != 2:
+ raise ValueError(f"gemv_fp8_bf16 requires 2D weight matrix, got {b_nk.ndim}D")
+ if a.dtype != bfloat16:
+ raise ValueError(f"gemv_fp8_bf16 requires bfloat16 activation, got {a.dtype}")
+ if b_nk.dtype != uint8:
+ raise ValueError(f"gemv_fp8_bf16 requires uint8 (FP8) weights, got {b_nk.dtype}")
+ if b_scale.dtype != bfloat16:
+ raise ValueError(f"gemv_fp8_bf16 requires bfloat16 scale, got {b_scale.dtype}")
+
+ K = a.shape[0]
+ N = b_nk.shape[0]
+
+ if b_nk.shape[1] != K:
+ raise ValueError(f"gemv_fp8_bf16 dimension mismatch: A[{K}] vs B[{N}, {b_nk.shape[1]}]")
+
+ if out is not None:
+ if out.shape != (N,):
+ raise ValueError(f"out shape {out.shape} does not match expected ({N},)")
+ if out.dtype != bfloat16:
+ raise ValueError(f"out dtype {out.dtype} must be bfloat16")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+ b_nk_native = b_nk._get_native()
+ b_scale_native = b_scale._get_native()
+
+ if out is None:
+ out_native = native.empty([N], native.DataType.BFloat16)
+ out = GPUArray._wrap_native(out_native)
+ else:
+ out_native = out._get_native()
+
+ native.gemv_fp8_bf16_sm120(a_native, b_nk_native, b_scale_native, out_native)
+ return out
+ else:
+ raise NotImplementedError("FP8 GEMV requires native GPU backend")
+
+
+gemv_fp8_bf16_sm120 = gemv_fp8_bf16
+
+
+def gemv_fp8_bf16_batched(
+ a: GPUArray,
+ b_nk: GPUArray,
+ b_scale: GPUArray,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """Optimized batched FP8 GEMV: C[M,N] = A[M,K] @ B[N,K]^T.
+
+ W8A16 GEMM for M>1: FP8 weights with BF16 activation and output.
+ """
+ from pygpukit.core.dtypes import bfloat16, uint8
+
+ if a.ndim != 2:
+ raise ValueError(f"gemv_fp8_bf16_batched requires 2D input matrix, got {a.ndim}D")
+ if b_nk.ndim != 2:
+ raise ValueError(f"gemv_fp8_bf16_batched requires 2D weight matrix, got {b_nk.ndim}D")
+ if a.dtype != bfloat16:
+ raise ValueError(f"gemv_fp8_bf16_batched requires bfloat16 activation, got {a.dtype}")
+ if b_nk.dtype != uint8:
+ raise ValueError(f"gemv_fp8_bf16_batched requires uint8 (FP8) weights, got {b_nk.dtype}")
+ if b_scale.dtype != bfloat16:
+ raise ValueError(f"gemv_fp8_bf16_batched requires bfloat16 scale, got {b_scale.dtype}")
+
+ M = a.shape[0]
+ K = a.shape[1]
+ N = b_nk.shape[0]
+
+ if b_nk.shape[1] != K:
+ raise ValueError(
+ f"gemv_fp8_bf16_batched dimension mismatch: A[{M},{K}] vs B[{N},{b_nk.shape[1]}]"
+ )
+
+ if out is not None:
+ if out.shape != (M, N):
+ raise ValueError(f"out shape {out.shape} does not match expected ({M}, {N})")
+ if out.dtype != bfloat16:
+ raise ValueError(f"out dtype {out.dtype} must be bfloat16")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+ b_nk_native = b_nk._get_native()
+ b_scale_native = b_scale._get_native()
+
+ if out is None:
+ out_native = native.empty([M, N], native.DataType.BFloat16)
+ out = GPUArray._wrap_native(out_native)
+ else:
+ out_native = out._get_native()
+
+ native.gemv_fp8_bf16_batched_sm120(a_native, b_nk_native, b_scale_native, out_native)
+ return out
+ else:
+ raise NotImplementedError("FP8 batched GEMV requires native GPU backend")
+
+
+gemv_fp8_bf16_batched_sm120 = gemv_fp8_bf16_batched
+
+
+__all__ = [
+ "gemv_bf16",
+ "gemv_bf16_bf16_sm120",
+ "gemv_fp8_bf16",
+ "gemv_fp8_bf16_sm120",
+ "gemv_fp8_bf16_batched",
+ "gemv_fp8_bf16_batched_sm120",
+]
diff --git a/src/pygpukit/ops/matmul/generic.py b/src/pygpukit/ops/matmul/generic.py
new file mode 100644
index 0000000..0291794
--- /dev/null
+++ b/src/pygpukit/ops/matmul/generic.py
@@ -0,0 +1,384 @@
+"""Generic matrix multiplication operations.
+
+Basic matmul, batched_matmul, transpose, and linear_bias_gelu.
+"""
+
+from __future__ import annotations
+
+import warnings
+
+import numpy as np
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.backend import NativeBackend, get_backend
+from pygpukit.core.factory import from_numpy
+from pygpukit.ops._common import _validate_float_dtype, _validate_same_dtype
+
+
+def matmul(
+ a: GPUArray,
+ b: GPUArray,
+ *,
+ out: GPUArray | None = None,
+ use_tf32: bool | None = None,
+) -> GPUArray:
+ """Matrix multiplication of two 2D arrays.
+
+ Args:
+ a: First input array (M x K).
+ b: Second input array (K x N).
+ out: Optional output array (M x N). If provided, result is written to this
+ array instead of allocating a new one. This enables CUDA Graph capture
+ since no memory allocation occurs during the operation.
+ use_tf32: Whether to use TF32 TensorCore acceleration (Ampere+ only).
+ - None (default): Use PYGPUKIT_ALLOW_TF32 environment variable
+ - True: Force TF32 mode (requires SM >= 80 and float32)
+ - False: Force FP32 mode
+
+ Returns:
+ The result GPUArray (M x N). If out is provided, returns out.
+
+ Raises:
+ ValueError: If arrays are not 2D or dimensions don't match.
+ RuntimeError: If use_tf32=True but GPU doesn't support it or dtype is not float32.
+ """
+ if a.ndim != 2:
+ raise ValueError(f"matmul requires 2D arrays, got {a.ndim}D for first argument")
+ if b.ndim != 2:
+ raise ValueError(f"matmul requires 2D arrays, got {b.ndim}D for second argument")
+
+ if a.shape[1] != b.shape[0]:
+ raise ValueError(
+ f"matmul dimension mismatch: {a.shape} @ {b.shape} "
+ f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)"
+ )
+
+ _validate_same_dtype(a, b, "matmul")
+
+ if out is not None:
+ expected_shape = (a.shape[0], b.shape[1])
+ if out.shape != expected_shape:
+ raise ValueError(f"out shape {out.shape} does not match expected {expected_shape}")
+ if out.dtype != a.dtype:
+ raise ValueError(f"out dtype {out.dtype} does not match input dtype {a.dtype}")
+
+ if use_tf32 is True:
+ from pygpukit.core.dtypes import float32
+
+ if a.dtype != float32:
+ raise RuntimeError("TF32 matmul requires float32 dtype")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ return _matmul_native(a, b, out=out, use_tf32=use_tf32)
+ else:
+ return _matmul_cpu(a, b, out=out)
+
+
+def _matmul_cpu(a: GPUArray, b: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
+ """CPU implementation of matmul."""
+ a_np = a.to_numpy()
+ b_np = b.to_numpy()
+ if out is not None:
+ out_np = out.to_numpy()
+ np.matmul(a_np, b_np, out=out_np)
+ out._data = from_numpy(out_np)._data
+ return out
+ else:
+ result_np = np.matmul(a_np, b_np)
+ return from_numpy(result_np)
+
+
+def _matmul_native(
+ a: GPUArray,
+ b: GPUArray,
+ *,
+ out: GPUArray | None = None,
+ use_tf32: bool | None = None,
+) -> GPUArray:
+ """Native C++ CUDA implementation of matmul (zero-copy)."""
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+ b_native = b._get_native()
+
+ if out is not None:
+ out_native = out._get_native()
+ if use_tf32 is not None:
+ native.matmul_tf32_(a_native, b_native, out_native, use_tf32)
+ else:
+ native.matmul_(a_native, b_native, out_native)
+ return out
+ else:
+ if use_tf32 is not None:
+ c_native = native.matmul_tf32(a_native, b_native, use_tf32)
+ else:
+ c_native = native.matmul(a_native, b_native)
+ return GPUArray._wrap_native(c_native)
+
+
+def transpose(a: GPUArray) -> GPUArray:
+ """Matrix transpose.
+
+ Args:
+ a: Input array of shape [rows, cols].
+
+ Returns:
+ A new GPUArray of shape [cols, rows] containing a.T.
+ """
+ if a.ndim != 2:
+ raise ValueError(f"transpose expects 2D input [rows, cols], got {a.ndim}D")
+
+ from pygpukit.core.dtypes import uint8
+
+ backend = get_backend()
+
+ if a.dtype == uint8:
+ return _transpose_cpu(a)
+
+ _validate_float_dtype(a, "transpose")
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ return _transpose_native(a)
+ else:
+ return _transpose_cpu(a)
+
+
+def _transpose_cpu(a: GPUArray) -> GPUArray:
+ """CPU implementation of transpose."""
+ a_np = a.to_numpy()
+ return from_numpy(a_np.T.copy())
+
+
+def _transpose_native(a: GPUArray) -> GPUArray:
+ """Native C++ CUDA implementation of transpose (zero-copy)."""
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+ c_native = native.transpose(a_native)
+ return GPUArray._wrap_native(c_native)
+
+
+def linear_bias_gelu(
+ input: GPUArray,
+ weight: GPUArray,
+ bias: GPUArray,
+) -> GPUArray:
+ """Fused linear + bias + GELU operation.
+
+ Computes: output = gelu(input @ weight^T + bias)
+ """
+ _validate_float_dtype(input, "linear_bias_gelu")
+
+ if input.ndim != 2:
+ raise ValueError(
+ f"linear_bias_gelu expects 2D input [batch, in_features], got {input.ndim}D"
+ )
+ if weight.ndim != 2:
+ raise ValueError(
+ f"linear_bias_gelu expects 2D weight [out_features, in_features], got {weight.ndim}D"
+ )
+ if bias.ndim != 1:
+ raise ValueError(f"linear_bias_gelu expects 1D bias [out_features], got {bias.ndim}D")
+
+ if input.dtype != weight.dtype or input.dtype != bias.dtype:
+ raise ValueError("linear_bias_gelu: all inputs must have same dtype")
+
+ in_features = input.shape[1]
+ out_features = weight.shape[0]
+
+ if weight.shape[1] != in_features:
+ raise ValueError(
+ f"linear_bias_gelu: weight.shape[1]={weight.shape[1]} must match "
+ f"input.shape[1]={in_features}"
+ )
+ if bias.shape[0] != out_features:
+ raise ValueError(
+ f"linear_bias_gelu: bias.shape[0]={bias.shape[0]} must match "
+ f"weight.shape[0]={out_features}"
+ )
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ return _linear_bias_gelu_native(input, weight, bias)
+ else:
+ return _linear_bias_gelu_cpu(input, weight, bias)
+
+
+def _linear_bias_gelu_cpu(
+ input: GPUArray,
+ weight: GPUArray,
+ bias: GPUArray,
+) -> GPUArray:
+ """CPU implementation of linear_bias_gelu."""
+ x = input.to_numpy()
+ w = weight.to_numpy()
+ b = bias.to_numpy()
+ y = x @ w.T + b
+ sqrt_2_over_pi = np.sqrt(2.0 / np.pi)
+ result = y * 0.5 * (1.0 + np.tanh(sqrt_2_over_pi * (y + 0.044715 * y**3)))
+ return from_numpy(result.astype(x.dtype))
+
+
+def _linear_bias_gelu_native(
+ input: GPUArray,
+ weight: GPUArray,
+ bias: GPUArray,
+) -> GPUArray:
+ """Native C++ CUDA implementation of linear_bias_gelu (CUTLASS fused kernel)."""
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ input_native = input._get_native()
+ weight_native = weight._get_native()
+ bias_native = bias._get_native()
+ c_native = native.linear_bias_gelu(input_native, weight_native, bias_native)
+ return GPUArray._wrap_native(c_native)
+
+
+def batched_matmul(
+ a: GPUArray,
+ b: GPUArray,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """Batched matrix multiplication for 3D and 4D tensors.
+
+ Supports:
+ - 3D: [batch, M, K] @ [batch, K, N] -> [batch, M, N]
+ - 4D: [batch1, batch2, M, K] @ [batch1, batch2, K, N] -> [batch1, batch2, M, N]
+ """
+ if a.ndim not in (3, 4):
+ raise ValueError(f"batched_matmul requires 3D or 4D arrays, got {a.ndim}D")
+ if b.ndim not in (3, 4):
+ raise ValueError(f"batched_matmul requires 3D or 4D arrays, got {b.ndim}D")
+ if a.ndim != b.ndim:
+ raise ValueError(f"batched_matmul requires same ndim, got {a.ndim}D and {b.ndim}D")
+
+ _validate_same_dtype(a, b, "batched_matmul")
+
+ if a.ndim == 3:
+ batch = a.shape[0]
+ M, K = a.shape[1], a.shape[2]
+ K2, N = b.shape[1], b.shape[2]
+ if b.shape[0] != batch:
+ raise ValueError(f"Batch dimension mismatch: {a.shape[0]} vs {b.shape[0]}")
+ if K != K2:
+ raise ValueError(f"Inner dimension mismatch: {K} vs {K2}")
+ out_shape = (batch, M, N)
+ batch_count = batch
+ else:
+ batch1, batch2 = a.shape[0], a.shape[1]
+ M, K = a.shape[2], a.shape[3]
+ K2, N = b.shape[2], b.shape[3]
+ if b.shape[0] != batch1 or b.shape[1] != batch2:
+ raise ValueError(
+ f"Batch dimensions mismatch: ({batch1}, {batch2}) vs ({b.shape[0]}, {b.shape[1]})"
+ )
+ if K != K2:
+ raise ValueError(f"Inner dimension mismatch: {K} vs {K2}")
+ out_shape = (batch1, batch2, M, N)
+ batch_count = batch1 * batch2
+
+ if out is not None:
+ if out.shape != out_shape:
+ raise ValueError(f"out shape {out.shape} does not match expected {out_shape}")
+ if out.dtype != a.dtype:
+ raise ValueError(f"out dtype {out.dtype} does not match input dtype {a.dtype}")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ return _batched_matmul_native(a, b, M, N, K, batch_count, out_shape, out=out)
+ else:
+ return _batched_matmul_cpu(a, b, out=out)
+
+
+def _batched_matmul_cpu(a: GPUArray, b: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
+ """CPU implementation of batched_matmul."""
+ a_np = a.to_numpy()
+ b_np = b.to_numpy()
+ result_np = np.matmul(a_np, b_np)
+ result = from_numpy(result_np)
+
+ if out is not None:
+ from pygpukit.ops.elementwise import copy_to
+
+ copy_to(result, out)
+ return out
+ else:
+ return result
+
+
+def _batched_matmul_native(
+ a: GPUArray,
+ b: GPUArray,
+ M: int,
+ N: int,
+ K: int,
+ batch_count: int,
+ out_shape: tuple[int, ...],
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """Native cuBLASLt strided batched GEMM implementation."""
+ from pygpukit.core.backend import get_native_module
+ from pygpukit.core.dtypes import float32
+
+ native = get_native_module()
+
+ if a.dtype != float32:
+ warnings.warn(
+ f"batched_matmul: GPU kernel requires float32, got {a.dtype}. Using CPU fallback (slow)",
+ RuntimeWarning,
+ stacklevel=3,
+ )
+ return _batched_matmul_cpu(a, b, out=out)
+
+ strideA = M * K
+ strideB = K * N
+ strideC = M * N
+
+ a_native = a._get_native()
+ b_native = b._get_native()
+
+ if out is None:
+ out_native = native.empty(list(out_shape), native.DataType.Float32)
+ out = GPUArray._wrap_native(out_native)
+ else:
+ out_native = out._get_native()
+
+ try:
+ native.gemm_strided_batched_fp32(
+ a_native,
+ b_native,
+ out_native,
+ M,
+ N,
+ K,
+ batch_count,
+ strideA,
+ strideB,
+ strideC,
+ )
+ except RuntimeError:
+ warnings.warn(
+ "batched_matmul: CUTLASS kernel failed, using CPU fallback (slow)",
+ RuntimeWarning,
+ stacklevel=3,
+ )
+ return _batched_matmul_cpu(a, b, out=out)
+
+ return out
+
+
+__all__ = [
+ "matmul",
+ "transpose",
+ "linear_bias_gelu",
+ "batched_matmul",
+]
diff --git a/src/pygpukit/ops/matmul/grouped.py b/src/pygpukit/ops/matmul/grouped.py
new file mode 100644
index 0000000..73a4a2d
--- /dev/null
+++ b/src/pygpukit/ops/matmul/grouped.py
@@ -0,0 +1,141 @@
+"""Grouped GEMM operations for MoE (Mixture of Experts).
+
+Grouped GEMM with per-row expert dispatching.
+"""
+
+from __future__ import annotations
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.backend import NativeBackend, get_backend
+
+# Track if grouped GEMM LUT is initialized
+_grouped_gemm_lut_initialized = False
+
+
+def grouped_gemm_init_lut() -> None:
+ """Initialize FP8->BF16 LUT for grouped GEMM.
+
+ This must be called once before using grouped_gemm_fp8_bf16.
+ """
+ global _grouped_gemm_lut_initialized
+ if _grouped_gemm_lut_initialized:
+ return
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ native.grouped_gemm_init_lut()
+ _grouped_gemm_lut_initialized = True
+ else:
+ raise NotImplementedError("Grouped GEMM requires native GPU backend")
+
+
+def grouped_gemm_fp8_bf16(
+ a: GPUArray,
+ b_stacked: GPUArray,
+ b_scale: GPUArray,
+ row_expert_ids: GPUArray,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """Grouped GEMM for MoE: C = A @ B_stacked with per-row expert IDs.
+
+ Each row has an associated expert ID, and the kernel dispatches to the
+ correct expert's weights for each row.
+
+ Args:
+ a: Input tokens [M, K], BF16.
+ b_stacked: Stacked expert weights [num_experts, N, K], FP8 (uint8).
+ b_scale: Block-wise scales [num_experts, N/128, K/128], BF16.
+ row_expert_ids: Expert ID for each row [M], int32.
+ out: Optional output tensor [M, N], BF16.
+
+ Returns:
+ Output tensor [M, N], BF16.
+ """
+ from pygpukit.core.dtypes import bfloat16, int32, uint8
+
+ if a.ndim != 2:
+ raise ValueError(f"grouped_gemm_fp8_bf16 requires 2D input, got {a.ndim}D")
+
+ if b_stacked.ndim != 3:
+ raise ValueError(f"grouped_gemm_fp8_bf16 requires 3D weight, got {b_stacked.ndim}D")
+
+ if a.dtype != bfloat16:
+ raise ValueError(f"grouped_gemm_fp8_bf16 requires bfloat16 input, got {a.dtype}")
+
+ if b_stacked.dtype != uint8:
+ raise ValueError(
+ f"grouped_gemm_fp8_bf16 requires uint8 (FP8) weights, got {b_stacked.dtype}"
+ )
+
+ if b_scale.dtype != bfloat16:
+ raise ValueError(f"grouped_gemm_fp8_bf16 requires bfloat16 scale, got {b_scale.dtype}")
+
+ if row_expert_ids.dtype != int32:
+ raise ValueError(
+ f"grouped_gemm_fp8_bf16 requires int32 row_expert_ids, got {row_expert_ids.dtype}"
+ )
+
+ M = a.shape[0]
+ K = a.shape[1]
+ N = b_stacked.shape[1]
+
+ if b_stacked.shape[2] != K:
+ raise ValueError(
+ f"grouped_gemm_fp8_bf16: K mismatch A[{M},{K}] vs B[...{N},{b_stacked.shape[2]}]"
+ )
+
+ if row_expert_ids.shape[0] != M:
+ raise ValueError(
+ f"grouped_gemm_fp8_bf16: row_expert_ids size {row_expert_ids.shape[0]} != M ({M})"
+ )
+
+ # Validate output
+ if out is not None:
+ if out.shape != (M, N):
+ raise ValueError(f"out shape {out.shape} does not match expected ({M}, {N})")
+ if out.dtype != bfloat16:
+ raise ValueError(f"out dtype {out.dtype} must be bfloat16")
+
+ # Initialize LUT if not already done
+ grouped_gemm_init_lut()
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+
+ a_native = a._get_native()
+ b_stacked_native = b_stacked._get_native()
+ b_scale_native = b_scale._get_native()
+ row_expert_ids_native = row_expert_ids._get_native()
+
+ if out is None:
+ out_native = native.empty([M, N], native.DataType.BFloat16)
+ out = GPUArray._wrap_native(out_native)
+ else:
+ out_native = out._get_native()
+
+ native.grouped_gemm_fp8_bf16_sm120(
+ a_native, b_stacked_native, b_scale_native, out_native, row_expert_ids_native
+ )
+
+ return out
+ else:
+ raise NotImplementedError("Grouped GEMM requires native GPU backend")
+
+
+grouped_gemm_fp8_bf16_sm120 = grouped_gemm_fp8_bf16
+
+
+__all__ = [
+ "grouped_gemm_init_lut",
+ "grouped_gemm_fp8_bf16",
+ "grouped_gemm_fp8_bf16_sm120",
+]
diff --git a/src/pygpukit/ops/matmul/nvf4.py b/src/pygpukit/ops/matmul/nvf4.py
new file mode 100644
index 0000000..57c0efb
--- /dev/null
+++ b/src/pygpukit/ops/matmul/nvf4.py
@@ -0,0 +1,205 @@
+"""NVF4 (4-bit float) operations.
+
+NVF4 provides 4x memory bandwidth compared to BF16.
+"""
+
+from __future__ import annotations
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.backend import NativeBackend, get_backend
+
+from .availability import gemv_nvf4_available, nvf4_bf16_sm120_available
+
+
+def nvf4_get_sizes(K: int, N: int) -> tuple[int, int]:
+ """Get buffer sizes for NVF4-quantized weights.
+
+ Args:
+ K: Inner dimension (input features).
+ N: Output dimension (output features).
+
+ Returns:
+ Tuple of (data_size, scale_size) in bytes.
+ """
+ data_size = (K // 2) * N
+ scale_size = ((K + 31) // 32) * N
+ return data_size, scale_size
+
+
+gemv_nvf4_get_sizes = nvf4_get_sizes
+
+
+def quantize_bf16_to_nvf4(
+ input: GPUArray,
+ out_data: GPUArray,
+ out_scale: GPUArray,
+) -> None:
+ """Quantize BF16 weights to NVF4 format with block scaling.
+
+ Args:
+ input: BF16 weight matrix [K, N].
+ out_data: Pre-allocated buffer for packed NVF4 data [K/2, N] (uint8).
+ out_scale: Pre-allocated buffer for scale factors [K/32, N] (uint8).
+ """
+ from pygpukit.core.dtypes import bfloat16
+
+ if input.ndim != 2:
+ raise ValueError(f"quantize_bf16_to_nvf4 requires 2D input, got {input.ndim}D")
+ if input.dtype != bfloat16:
+ raise ValueError(f"quantize_bf16_to_nvf4 requires bfloat16 input, got {input.dtype}")
+ if not gemv_nvf4_available():
+ raise RuntimeError("NVF4 quantization not available. Requires SM120+ GPU.")
+
+ K, N = input.shape
+ expected_data_size, expected_scale_size = nvf4_get_sizes(K, N)
+
+ actual_data_size = (
+ out_data.shape[0] * out_data.shape[1] if out_data.ndim == 2 else out_data.size
+ )
+ actual_scale_size = (
+ out_scale.shape[0] * out_scale.shape[1] if out_scale.ndim == 2 else out_scale.size
+ )
+
+ if actual_data_size < expected_data_size:
+ raise ValueError(f"out_data buffer too small: {actual_data_size} < {expected_data_size}")
+ if actual_scale_size < expected_scale_size:
+ raise ValueError(f"out_scale buffer too small: {actual_scale_size} < {expected_scale_size}")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ input_native = input._get_native()
+ data_native = out_data._get_native()
+ scale_native = out_scale._get_native()
+ native.quantize_bf16_to_nvf4(input_native, data_native, scale_native)
+
+
+def matmul_nvf4_bf16_sm120(
+ a: GPUArray,
+ b: GPUArray,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """NVF4 (4-bit) GEMM with BF16 input/output for SM120.
+
+ Data flow: BF16 input -> NVF4 quantize with block scaling -> GEMM -> BF16 output
+ """
+ from pygpukit.core.dtypes import bfloat16
+
+ if a.ndim != 2:
+ raise ValueError(f"matmul_nvf4_bf16_sm120 requires 2D arrays, got {a.ndim}D")
+ if b.ndim != 2:
+ raise ValueError(f"matmul_nvf4_bf16_sm120 requires 2D arrays, got {b.ndim}D")
+
+ if a.shape[1] != b.shape[0]:
+ raise ValueError(f"matmul_nvf4_bf16_sm120 dimension mismatch: {a.shape} @ {b.shape}")
+
+ if a.dtype != bfloat16 or b.dtype != bfloat16:
+ raise ValueError("matmul_nvf4_bf16_sm120 requires bfloat16 inputs")
+
+ if not nvf4_bf16_sm120_available():
+ raise RuntimeError("NVF4 BF16 SM120 GEMM is not available. Requires SM120+ GPU.")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+ b_native = b._get_native()
+
+ if out is None:
+ M, K = a.shape
+ N = b.shape[1]
+ out_native = native.empty([M, N], native.DataType.BFloat16)
+ out = GPUArray._wrap_native(out_native)
+ else:
+ out_native = out._get_native()
+
+ native.gemm_nvf4_bf16_sm120(a_native, b_native, out_native)
+ return out
+ else:
+ raise RuntimeError("NVF4 BF16 SM120 GEMM requires native backend")
+
+
+gemm_nvf4_bf16_sm120 = matmul_nvf4_bf16_sm120
+
+
+def gemv_nvf4_bf16(
+ a: GPUArray,
+ b_data: GPUArray,
+ b_scale: GPUArray,
+ *,
+ out: GPUArray | None = None,
+ alpha: float = 1.0,
+) -> GPUArray:
+ """NVF4 GEMV: C[N] = alpha * A[K] @ B[K,N] (NVF4 quantized).
+
+ Args:
+ a: Input vector [K], BF16.
+ b_data: Packed NVF4 weight data [K/2, N], uint8.
+ b_scale: UE4M3 scale factors [K/32, N], uint8.
+ out: Optional output vector [N], BF16.
+ alpha: Scaling factor (default 1.0).
+
+ Returns:
+ Output vector [N], BF16.
+ """
+ from pygpukit.core.dtypes import bfloat16
+
+ if a.ndim != 1:
+ raise ValueError(f"gemv_nvf4_bf16 requires 1D input vector, got {a.ndim}D")
+ if a.dtype != bfloat16:
+ raise ValueError(f"gemv_nvf4_bf16 requires bfloat16 input, got {a.dtype}")
+ if not gemv_nvf4_available():
+ raise RuntimeError("NVF4 GEMV not available. Requires SM120+ GPU.")
+
+ if b_data.ndim == 2:
+ N = b_data.shape[1]
+ else:
+ raise ValueError(f"b_data must be 2D [K/2, N], got {b_data.ndim}D")
+
+ if out is not None:
+ if out.shape != (N,):
+ raise ValueError(f"out shape {out.shape} does not match expected ({N},)")
+ if out.dtype != bfloat16:
+ raise ValueError(f"out dtype {out.dtype} must be bfloat16")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+ data_native = b_data._get_native()
+ scale_native = b_scale._get_native()
+
+ if out is None:
+ out_native = native.empty([N], native.DataType.BFloat16)
+ out = GPUArray._wrap_native(out_native)
+ else:
+ out_native = out._get_native()
+
+ native.gemv_nvf4_bf16_sm120(a_native, data_native, scale_native, out_native, alpha)
+ return out
+ else:
+ raise RuntimeError("NVF4 GEMV requires native backend")
+
+
+gemv_nvf4_bf16_sm120 = gemv_nvf4_bf16
+
+
+__all__ = [
+ "nvf4_get_sizes",
+ "gemv_nvf4_get_sizes",
+ "quantize_bf16_to_nvf4",
+ "matmul_nvf4_bf16_sm120",
+ "gemm_nvf4_bf16_sm120",
+ "gemv_nvf4_bf16",
+ "gemv_nvf4_bf16_sm120",
+]
diff --git a/src/pygpukit/ops/matmul/w8a16.py b/src/pygpukit/ops/matmul/w8a16.py
new file mode 100644
index 0000000..99214b9
--- /dev/null
+++ b/src/pygpukit/ops/matmul/w8a16.py
@@ -0,0 +1,128 @@
+"""W8A16 GEMM operations.
+
+Weight 8-bit (FP8), Activation 16-bit (BF16) GEMM.
+"""
+
+from __future__ import annotations
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.backend import NativeBackend, get_backend
+
+# Flag to track if W8A16 GEMM LUT has been initialized
+_W8A16_GEMM_LUT_INITIALIZED = False
+
+
+def w8a16_gemm_init_lut() -> None:
+ """Initialize FP8->F32 LUT for W8A16 GEMM.
+
+ This uses runtime initialization to avoid symbol conflicts with the GEMV LUT.
+ Must be called before using w8a16_gemm_sm120.
+ """
+ global _W8A16_GEMM_LUT_INITIALIZED
+ if _W8A16_GEMM_LUT_INITIALIZED:
+ return
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ native.gemm_w8a16_init_lut()
+ _W8A16_GEMM_LUT_INITIALIZED = True
+
+
+gemm_w8a16_init_lut = w8a16_gemm_init_lut
+
+
+def w8a16_gemm_sm120(
+ a: GPUArray,
+ b_fp8: GPUArray,
+ b_scale: GPUArray,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """W8A16 GEMM for SM120: C[M,N] = A[M,K] @ dequant(B_fp8[K,N]).
+
+ FP8 weight x BF16 activation -> BF16 output.
+ Uses TensorCore GEMM with online FP8 dequantization.
+ More efficient than batched GEMV for M > 1.
+
+ Args:
+ a: Activation matrix [M, K], BF16.
+ b_fp8: FP8 E4M3 weight matrix [K, N], stored as uint8.
+ b_scale: Block-wise scale factors [K/128, N/128], BF16.
+ out: Optional output matrix [M, N], BF16.
+
+ Returns:
+ Output matrix [M, N], BF16.
+ """
+ from pygpukit.core.dtypes import bfloat16, uint8
+
+ if a.ndim != 2:
+ raise ValueError(f"w8a16_gemm_sm120 requires 2D input matrix, got {a.ndim}D")
+
+ if b_fp8.ndim != 2:
+ raise ValueError(f"w8a16_gemm_sm120 requires 2D weight matrix, got {b_fp8.ndim}D")
+
+ if a.dtype != bfloat16:
+ raise ValueError(f"w8a16_gemm_sm120 requires bfloat16 activation, got {a.dtype}")
+
+ if b_fp8.dtype != uint8:
+ raise ValueError(f"w8a16_gemm_sm120 requires uint8 (FP8) weights, got {b_fp8.dtype}")
+
+ if b_scale.dtype != bfloat16:
+ raise ValueError(f"w8a16_gemm_sm120 requires bfloat16 scale, got {b_scale.dtype}")
+
+ M = a.shape[0]
+ K = a.shape[1]
+ if b_fp8.shape[0] != K:
+ raise ValueError(
+ f"w8a16_gemm_sm120 dimension mismatch: A[{M},{K}] vs B[{b_fp8.shape[0]}, {b_fp8.shape[1]}]"
+ )
+
+ N = b_fp8.shape[1]
+
+ # Validate output
+ if out is not None:
+ if out.shape != (M, N):
+ raise ValueError(f"out shape {out.shape} does not match expected ({M}, {N})")
+ if out.dtype != bfloat16:
+ raise ValueError(f"out dtype {out.dtype} must be bfloat16")
+
+ # Initialize W8A16 GEMM LUT (runtime initialization to avoid symbol conflicts)
+ w8a16_gemm_init_lut()
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+
+ a_native = a._get_native()
+ b_fp8_native = b_fp8._get_native()
+ b_scale_native = b_scale._get_native()
+
+ if out is None:
+ out_native = native.empty([M, N], native.DataType.BFloat16)
+ out = GPUArray._wrap_native(out_native)
+ else:
+ out_native = out._get_native()
+
+ native.gemm_w8a16_bf16_sm120(a_native, b_fp8_native, b_scale_native, out_native)
+
+ return out
+ else:
+ raise NotImplementedError("W8A16 GEMM requires native GPU backend with SM120")
+
+
+gemm_w8a16_bf16_sm120 = w8a16_gemm_sm120
+
+
+__all__ = [
+ "w8a16_gemm_init_lut",
+ "gemm_w8a16_init_lut",
+ "w8a16_gemm_sm120",
+ "gemm_w8a16_bf16_sm120",
+]
diff --git a/src/pygpukit/ops/nn.py b/src/pygpukit/ops/nn.py
deleted file mode 100644
index ecf6f8f..0000000
--- a/src/pygpukit/ops/nn.py
+++ /dev/null
@@ -1,1016 +0,0 @@
-"""Neural network operations for GPUArrays.
-
-Corresponds to native/ops/nn/.
-"""
-
-from __future__ import annotations
-
-import numpy as np
-
-from pygpukit.core.array import GPUArray
-from pygpukit.core.backend import NativeBackend, get_backend
-from pygpukit.core.factory import from_numpy
-from pygpukit.ops._common import _validate_float_dtype
-
-# =============================================================================
-# Activation Functions
-# =============================================================================
-
-
-def gelu(a: GPUArray) -> GPUArray:
- """GELU (Gaussian Error Linear Unit) activation.
-
- Computes: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
-
- Args:
- a: Input array (float32, float64, float16, or bfloat16).
-
- Returns:
- A new GPUArray containing gelu(a).
-
- Raises:
- ValueError: If dtype is not a float type.
- """
- _validate_float_dtype(a, "gelu")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _gelu_native(a)
- else:
- return _gelu_cpu(a)
-
-
-def _gelu_cpu(a: GPUArray) -> GPUArray:
- """CPU implementation of gelu."""
- a_np = a.to_numpy()
- # GELU approximation
- x = a_np.astype(np.float32) if a_np.dtype in [np.float16] else a_np
- c1 = 0.7978845608 # sqrt(2/pi)
- c2 = 0.044715
- result = x * 0.5 * (1 + np.tanh(c1 * (x + c2 * x**3)))
- return from_numpy(result.astype(a_np.dtype))
-
-
-def _gelu_native(a: GPUArray) -> GPUArray:
- """Native C++ CUDA implementation of gelu (zero-copy)."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- a_native = a._get_native()
- c_native = native.gelu(a_native)
- return GPUArray._wrap_native(c_native)
-
-
-def silu(a: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
- """SiLU (Swish) activation: y = x * sigmoid(x).
-
- Used in Llama and other modern LLMs as the activation in MLP layers.
-
- Args:
- a: Input array.
- out: Optional pre-allocated output array. If provided, the result
- is written to this array (for CUDA Graph capture support).
-
- Returns:
- A new GPUArray containing the SiLU-activated values, or the out array if provided.
-
- Raises:
- ValueError: If dtype is not a float type.
- """
- _validate_float_dtype(a, "silu")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _silu_native(a, out=out)
- else:
- return _silu_cpu(a)
-
-
-def _silu_cpu(a: GPUArray) -> GPUArray:
- """CPU implementation of SiLU."""
- x = a.to_numpy()
- # SiLU = x * sigmoid(x) = x / (1 + exp(-x))
- result = x / (1.0 + np.exp(-x))
- return from_numpy(result)
-
-
-def _silu_native(a: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
- """Native C++ CUDA implementation of SiLU (zero-copy)."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- a_native = a._get_native()
-
- if out is not None:
- out_native = out._get_native()
- native.silu_(a_native, out_native)
- return out
- else:
- c_native = native.silu(a_native)
- return GPUArray._wrap_native(c_native)
-
-
-def sigmoid(a: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
- """Sigmoid activation: y = 1 / (1 + exp(-x)).
-
- Args:
- a: Input array.
- out: Optional pre-allocated output array.
-
- Returns:
- A new GPUArray containing the sigmoid-activated values.
- """
- _validate_float_dtype(a, "sigmoid")
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- a_native = a._get_native()
-
- if out is not None:
- out_native = out._get_native()
- native.sigmoid_(a_native, out_native)
- return out
- else:
- return GPUArray._wrap_native(native.sigmoid(a_native))
- else:
- x = a.to_numpy()
- result = 1.0 / (1.0 + np.exp(-x))
- return from_numpy(result)
-
-
-def tanh(a: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
- """Tanh activation.
-
- Args:
- a: Input array.
- out: Optional pre-allocated output array.
-
- Returns:
- A new GPUArray containing the tanh-activated values.
- """
- _validate_float_dtype(a, "tanh")
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- a_native = a._get_native()
-
- if out is not None:
- out_native = out._get_native()
- native.tanh_(a_native, out_native)
- return out
- else:
- return GPUArray._wrap_native(native.tanh(a_native))
- else:
- x = a.to_numpy()
- return from_numpy(np.tanh(x))
-
-
-# =============================================================================
-# Normalization Layers
-# =============================================================================
-
-
-def layernorm(
- input: GPUArray,
- gamma: GPUArray,
- beta: GPUArray,
- eps: float = 1e-5,
-) -> GPUArray:
- """Layer normalization.
-
- Computes: (x - mean) / sqrt(var + eps) * gamma + beta
-
- Args:
- input: Input array of shape [batch, features] or [batch, seq_len, features].
- gamma: Scale parameter of shape [features].
- beta: Bias parameter of shape [features].
- eps: Small epsilon for numerical stability.
-
- Returns:
- A new GPUArray containing the normalized output.
-
- Raises:
- ValueError: If shapes or dtypes don't match.
- """
- _validate_float_dtype(input, "layernorm")
-
- if input.ndim not in (2, 3):
- raise ValueError(f"layernorm expects 2D or 3D input, got {input.ndim}D")
- if gamma.ndim != 1 or beta.ndim != 1:
- raise ValueError("layernorm expects 1D gamma and beta")
- if input.dtype != gamma.dtype or input.dtype != beta.dtype:
- raise ValueError("layernorm: all inputs must have same dtype")
-
- features = input.shape[-1] # Last dimension is features
- if gamma.shape[0] != features or beta.shape[0] != features:
- raise ValueError(
- f"layernorm: gamma/beta size {gamma.shape[0]} must match features {features}"
- )
-
- # Handle 3D input by reshaping to 2D, processing, and reshaping back
- if input.ndim == 3:
- batch, seq_len, feat = input.shape
- input_2d = input.reshape(batch * seq_len, feat)
- result_2d = _layernorm_dispatch(input_2d, gamma, beta, eps)
- return result_2d.reshape(batch, seq_len, feat)
- else:
- return _layernorm_dispatch(input, gamma, beta, eps)
-
-
-def _layernorm_dispatch(
- input: GPUArray,
- gamma: GPUArray,
- beta: GPUArray,
- eps: float,
-) -> GPUArray:
- """Dispatch layernorm to native or CPU implementation."""
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _layernorm_native(input, gamma, beta, eps)
- else:
- return _layernorm_cpu(input, gamma, beta, eps)
-
-
-def _layernorm_cpu(
- input: GPUArray,
- gamma: GPUArray,
- beta: GPUArray,
- eps: float,
-) -> GPUArray:
- """CPU implementation of layernorm."""
- x = input.to_numpy()
- g = gamma.to_numpy()
- b = beta.to_numpy()
-
- # Compute mean and variance along features axis
- mean = x.mean(axis=1, keepdims=True)
- var = x.var(axis=1, keepdims=True)
-
- # Normalize
- normalized = (x - mean) / np.sqrt(var + eps)
-
- # Apply affine transform
- result = normalized * g + b
- return from_numpy(result)
-
-
-def _layernorm_native(
- input: GPUArray,
- gamma: GPUArray,
- beta: GPUArray,
- eps: float,
-) -> GPUArray:
- """Native C++ CUDA implementation of layernorm (zero-copy)."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- input_native = input._get_native()
- gamma_native = gamma._get_native()
- beta_native = beta._get_native()
- c_native = native.layernorm(input_native, gamma_native, beta_native, eps)
- return GPUArray._wrap_native(c_native)
-
-
-def rmsnorm(
- input: GPUArray,
- gamma: GPUArray,
- eps: float = 1e-5,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """RMS Normalization (Root Mean Square Normalization).
-
- Computes: x / sqrt(mean(x^2) + eps) * gamma
-
- Simpler than LayerNorm (no mean subtraction, no beta).
- Used in Llama and other modern LLMs.
-
- Args:
- input: Input array of shape [batch, features].
- gamma: Scale parameter of shape [features].
- eps: Small epsilon for numerical stability.
- out: Optional output buffer. If provided, result is written in-place
- (for CUDA Graph capture).
-
- Returns:
- A new GPUArray containing the normalized output (or out if provided).
-
- Raises:
- ValueError: If shapes or dtypes don't match.
- """
- _validate_float_dtype(input, "rmsnorm")
-
- if input.ndim != 2:
- raise ValueError(f"rmsnorm expects 2D input [batch, features], got {input.ndim}D")
- if gamma.ndim != 1:
- raise ValueError("rmsnorm expects 1D gamma")
- if input.dtype != gamma.dtype:
- raise ValueError("rmsnorm: all inputs must have same dtype")
-
- features = input.shape[1]
- if gamma.shape[0] != features:
- raise ValueError(f"rmsnorm: gamma size {gamma.shape[0]} must match features {features}")
-
- # Validate out array if provided
- if out is not None:
- if out.shape != input.shape:
- raise ValueError(f"out shape {out.shape} does not match input shape {input.shape}")
- if out.dtype != input.dtype:
- raise ValueError(f"out dtype {out.dtype} does not match input dtype {input.dtype}")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _rmsnorm_native(input, gamma, eps, out=out)
- else:
- return _rmsnorm_cpu(input, gamma, eps, out=out)
-
-
-def _rmsnorm_cpu(
- input: GPUArray,
- gamma: GPUArray,
- eps: float,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """CPU implementation of rmsnorm."""
- x = input.to_numpy()
- g = gamma.to_numpy()
-
- # RMS = sqrt(mean(x^2) + eps)
- rms = np.sqrt(np.mean(x**2, axis=1, keepdims=True) + eps)
-
- # Normalize and scale
- result = (x / rms) * g
-
- if out is not None:
- out_np = out.to_numpy()
- np.copyto(out_np, result)
- out._data = from_numpy(out_np)._data
- return out
- return from_numpy(result)
-
-
-def _rmsnorm_native(
- input: GPUArray,
- gamma: GPUArray,
- eps: float,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Native C++ CUDA implementation of rmsnorm (zero-copy)."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- input_native = input._get_native()
- gamma_native = gamma._get_native()
-
- if out is not None:
- out_native = out._get_native()
- native.rmsnorm_(input_native, gamma_native, out_native, eps)
- return out
- else:
- c_native = native.rmsnorm(input_native, gamma_native, eps)
- return GPUArray._wrap_native(c_native)
-
-
-# =============================================================================
-# Bias Operations
-# =============================================================================
-
-
-def bias_add_inplace(output: GPUArray, bias: GPUArray) -> None:
- """Add bias to output in-place.
-
- Computes: output[batch, features] += bias[features]
-
- Args:
- output: Output array of shape [batch, features] (modified in-place).
- bias: Bias array of shape [features].
-
- Raises:
- ValueError: If shapes don't match or dtypes don't match.
- """
- _validate_float_dtype(output, "bias_add_inplace")
-
- if output.ndim != 2:
- raise ValueError(
- f"bias_add_inplace expects 2D output [batch, features], got {output.ndim}D"
- )
- if bias.ndim != 1:
- raise ValueError(f"bias_add_inplace expects 1D bias [features], got {bias.ndim}D")
- if output.dtype != bias.dtype:
- raise ValueError("bias_add_inplace: output and bias must have same dtype")
-
- features = output.shape[1]
- if bias.shape[0] != features:
- raise ValueError(
- f"bias_add_inplace: bias size {bias.shape[0]} must match features {features}"
- )
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- _bias_add_inplace_native(output, bias)
- else:
- _bias_add_inplace_cpu(output, bias)
-
-
-def _bias_add_inplace_cpu(output: GPUArray, bias: GPUArray) -> None:
- """CPU implementation of bias_add_inplace."""
- # For CPU backend, we need to get numpy arrays, modify, and update
- output_np = output.to_numpy()
- bias_np = bias.to_numpy()
- output_np += bias_np
- # Note: This creates a new array - for CPU backend, in-place is not truly in-place
- # The native backend does true in-place modification
- output._data = from_numpy(output_np)._data
-
-
-def _bias_add_inplace_native(output: GPUArray, bias: GPUArray) -> None:
- """Native C++ CUDA implementation of bias_add_inplace (true in-place)."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- output_native = output._get_native()
- bias_native = bias._get_native()
- native.bias_add_inplace(output_native, bias_native)
-
-
-# =============================================================================
-# Attention Operations
-# =============================================================================
-
-
-def sdpa_causal(
- Q: GPUArray,
- K: GPUArray,
- V: GPUArray,
- scale: float = 0.0,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Scaled Dot-Product Attention with causal mask.
-
- Computes attention with automatic causal masking for autoregressive
- sequence generation. This is the core attention operation used in
- transformer models.
-
- Algorithm:
- scores = Q @ K^T / scale
- scores = apply_causal_mask(scores)
- weights = softmax(scores)
- output = weights @ V
-
- Args:
- Q: Query tensor of shape [n_heads, q_len, head_dim].
- K: Key tensor of shape [n_heads, kv_len, head_dim].
- V: Value tensor of shape [n_heads, kv_len, head_dim].
- scale: Scaling factor (typically 1/sqrt(head_dim)).
- If <= 0, computed automatically from head_dim.
- out: Optional output buffer [n_heads, q_len, head_dim].
- If provided, result is written in-place (for CUDA Graph capture).
-
- Returns:
- Output tensor of shape [n_heads, q_len, head_dim].
-
- Raises:
- ValueError: If shapes or dtypes don't match.
-
- Note:
- For KV cache usage during inference, kv_len >= q_len.
- The causal mask ensures query at position i can only attend
- to key positions 0 to (kv_len - q_len + i).
- """
- _validate_float_dtype(Q, "sdpa_causal")
-
- if Q.ndim != 3 or K.ndim != 3 or V.ndim != 3:
- raise ValueError("sdpa_causal expects 3D inputs [n_heads, seq_len, head_dim]")
- if Q.dtype != K.dtype or Q.dtype != V.dtype:
- raise ValueError("sdpa_causal: Q, K, V must have same dtype")
-
- n_heads, q_len, head_dim = Q.shape
-
- if K.shape[0] != n_heads or V.shape[0] != n_heads:
- raise ValueError("sdpa_causal: n_heads mismatch")
- if K.shape[2] != head_dim or V.shape[2] != head_dim:
- raise ValueError("sdpa_causal: head_dim mismatch")
- if K.shape[1] != V.shape[1]:
- raise ValueError("sdpa_causal: K and V seq_len mismatch")
-
- # Validate out array if provided
- if out is not None:
- if out.shape != (n_heads, q_len, head_dim):
- raise ValueError(
- f"out shape {out.shape} does not match expected {(n_heads, q_len, head_dim)}"
- )
- if out.dtype != Q.dtype:
- raise ValueError(f"out dtype {out.dtype} does not match Q dtype {Q.dtype}")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- return _sdpa_causal_native(Q, K, V, scale, out=out)
- else:
- return _sdpa_causal_cpu(Q, K, V, scale, out=out)
-
-
-def _sdpa_causal_cpu(
- Q: GPUArray,
- K: GPUArray,
- V: GPUArray,
- scale: float,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """CPU implementation of SDPA with causal mask."""
- q = Q.to_numpy()
- k = K.to_numpy()
- v = V.to_numpy()
-
- n_heads, q_len, head_dim = q.shape
- kv_len = k.shape[1]
-
- if scale <= 0:
- scale = 1.0 / np.sqrt(head_dim)
-
- # scores: [n_heads, q_len, kv_len]
- scores = np.matmul(q, k.transpose(0, 2, 1)) * scale
-
- # Create causal mask
- causal_offset = kv_len - q_len
- for i in range(q_len):
- max_attend = causal_offset + i + 1
- if max_attend < kv_len:
- scores[:, i, max_attend:] = -np.inf
-
- # Softmax over last dimension
- scores_max = scores.max(axis=-1, keepdims=True)
- exp_scores = np.exp(scores - scores_max)
- weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)
-
- # output: [n_heads, q_len, head_dim]
- output = np.matmul(weights, v)
-
- if out is not None:
- out_np = out.to_numpy()
- np.copyto(out_np, output.astype(q.dtype))
- out._data = from_numpy(out_np)._data
- return out
- return from_numpy(output.astype(q.dtype))
-
-
-def _sdpa_causal_native(
- Q: GPUArray,
- K: GPUArray,
- V: GPUArray,
- scale: float,
- *,
- out: GPUArray | None = None,
-) -> GPUArray:
- """Native C++ CUDA implementation of SDPA with causal mask."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- q_native = Q._get_native()
- k_native = K._get_native()
- v_native = V._get_native()
-
- if out is not None:
- out_native = out._get_native()
- native.sdpa_causal_(q_native, k_native, v_native, out_native, scale)
- return out
- else:
- c_native = native.sdpa_causal(q_native, k_native, v_native, scale)
- return GPUArray._wrap_native(c_native)
-
-
-def sdpa_causal_fixed_cache(
- Q: GPUArray,
- K: GPUArray,
- V: GPUArray,
- out: GPUArray,
- context_len: int,
- scale: float = 0.0,
-) -> None:
- """SDPA with fixed-length KV cache for CUDA Graph capture.
-
- This variant is designed for use with pre-allocated KV caches where
- the buffer size (max_seq_len) is larger than the actual context length.
-
- Args:
- Q: Query tensor of shape [n_heads, q_len, head_dim].
- K: Key cache of shape [n_heads, max_seq_len, head_dim].
- V: Value cache of shape [n_heads, max_seq_len, head_dim].
- out: Pre-allocated output buffer [n_heads, q_len, head_dim].
- context_len: Actual number of valid tokens in KV cache.
- scale: Scaling factor (typically 1/sqrt(head_dim)).
- If <= 0, computed automatically from head_dim.
-
- Raises:
- ValueError: If shapes or dtypes don't match, or context_len is invalid.
- """
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- q_native = Q._get_native()
- k_native = K._get_native()
- v_native = V._get_native()
- out_native = out._get_native()
-
- native.sdpa_causal_fixed_cache(q_native, k_native, v_native, out_native, context_len, scale)
-
-
-def sdpa_causal_fixed_cache_ptr(
- Q: GPUArray,
- K: GPUArray,
- V: GPUArray,
- out: GPUArray,
- context_len_buf: GPUArray,
- max_kv_len: int,
- scale: float = 0.0,
-) -> None:
- """SDPA with pointer-based context_len for CUDA Graph replay.
-
- This variant reads context_len from a GPU buffer at runtime, enabling
- CUDA Graph replay with dynamic context lengths without re-capture.
-
- Args:
- Q: Query tensor of shape [n_heads, q_len, head_dim].
- K: Key cache of shape [n_heads, max_seq_len, head_dim].
- V: Value cache of shape [n_heads, max_seq_len, head_dim].
- out: Pre-allocated output buffer [n_heads, q_len, head_dim].
- context_len_buf: GPU int32 buffer containing actual context_len [1].
- max_kv_len: Maximum context length (for shared memory allocation
- during graph capture). Must be <= K.shape[1].
- scale: Scaling factor (typically 1/sqrt(head_dim)).
- If <= 0, computed automatically from head_dim.
-
- Note:
- For CUDA Graph: capture with max_kv_len, then update context_len_buf
- before each replay to change the effective context length.
- """
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- q_native = Q._get_native()
- k_native = K._get_native()
- v_native = V._get_native()
- out_native = out._get_native()
- ctx_buf_native = context_len_buf._get_native()
-
- native.sdpa_causal_fixed_cache_ptr(
- q_native, k_native, v_native, out_native, ctx_buf_native, max_kv_len, scale
- )
-
-
-# =============================================================================
-# RoPE (Rotary Position Embedding)
-# =============================================================================
-
-
-def rope_inplace(
- q: GPUArray,
- k: GPUArray,
- cos: GPUArray,
- sin: GPUArray,
-) -> None:
- """Apply Rotary Position Embedding (RoPE) to Q and K tensors in-place.
-
- Args:
- q: Query tensor of shape [seq_len, n_heads_q, head_dim] (modified in-place).
- k: Key tensor of shape [seq_len, n_heads_k, head_dim] (modified in-place).
- cos: Precomputed cosine of shape [seq_len, head_dim].
- sin: Precomputed sine of shape [seq_len, head_dim].
-
- Note:
- This operation modifies q and k in-place.
- Works with GQA (n_heads_k can be different from n_heads_q).
- """
- _validate_float_dtype(q, "rope_inplace")
-
- if q.ndim != 3 or k.ndim != 3:
- raise ValueError("rope_inplace expects 3D q, k [seq_len, n_heads, head_dim]")
- if cos.ndim != 2 or sin.ndim != 2:
- raise ValueError("rope_inplace expects 2D cos, sin [seq_len, head_dim]")
-
- backend = get_backend()
-
- if isinstance(backend, NativeBackend) and backend.is_available():
- _rope_inplace_native(q, k, cos, sin)
- else:
- _rope_inplace_cpu(q, k, cos, sin)
-
-
-def _rope_inplace_cpu(
- q: GPUArray,
- k: GPUArray,
- cos: GPUArray,
- sin: GPUArray,
-) -> None:
- """CPU implementation of rope_inplace."""
-
- q_np = q.to_numpy()
- k_np = k.to_numpy()
- cos_np = cos.to_numpy()
- sin_np = sin.to_numpy()
-
- seq_len, n_heads_q, head_dim = q_np.shape
- n_heads_k = k_np.shape[1]
- half_dim = head_dim // 2
-
- # Apply RoPE to Q
- for s in range(seq_len):
- c = cos_np[s, :half_dim]
- sn = sin_np[s, :half_dim]
- for h in range(n_heads_q):
- q0 = q_np[s, h, :half_dim].copy()
- q1 = q_np[s, h, half_dim:].copy()
- q_np[s, h, :half_dim] = q0 * c - q1 * sn
- q_np[s, h, half_dim:] = q1 * c + q0 * sn
-
- # Apply RoPE to K
- for s in range(seq_len):
- c = cos_np[s, :half_dim]
- sn = sin_np[s, :half_dim]
- for h in range(n_heads_k):
- k0 = k_np[s, h, :half_dim].copy()
- k1 = k_np[s, h, half_dim:].copy()
- k_np[s, h, :half_dim] = k0 * c - k1 * sn
- k_np[s, h, half_dim:] = k1 * c + k0 * sn
-
- # Update the GPUArray data in-place
- q._data = from_numpy(q_np)._data
- k._data = from_numpy(k_np)._data
-
-
-def _rope_inplace_native(
- q: GPUArray,
- k: GPUArray,
- cos: GPUArray,
- sin: GPUArray,
-) -> None:
- """Native C++ CUDA implementation of rope_inplace."""
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- q_native = q._get_native()
- k_native = k._get_native()
- cos_native = cos._get_native()
- sin_native = sin._get_native()
- native.rope_inplace(q_native, k_native, cos_native, sin_native)
-
-
-def rope_inplace_f32table(
- q: GPUArray,
- k: GPUArray,
- cos: GPUArray,
- sin: GPUArray,
-) -> None:
- """Apply RoPE with FP32 cos/sin tables (higher precision for bf16/f16).
-
- Uses FP32 cos/sin tables for higher precision computation, avoiding
- the need to convert tables to bf16/f16.
-
- Args:
- q: Query tensor [seq_len, n_heads_q, head_dim] (bf16 or f16, modified in-place).
- k: Key tensor [seq_len, n_heads_k, head_dim] (bf16 or f16, modified in-place).
- cos: Precomputed cosine [seq_len, head_dim] (f32).
- sin: Precomputed sine [seq_len, head_dim] (f32).
- """
- from pygpukit.core.backend import get_native_module
-
- native = get_native_module()
- q_native = q._get_native()
- k_native = k._get_native()
- cos_native = cos._get_native()
- sin_native = sin._get_native()
- native.rope_inplace_f32table(q_native, k_native, cos_native, sin_native)
-
-
-# =============================================================================
-# QKV Split Operations
-# =============================================================================
-
-
-def split_qkv_batch(
- qkv: GPUArray,
- q_out: GPUArray,
- k_out: GPUArray,
- v_out: GPUArray,
- q_dim: int,
- k_dim: int,
- v_dim: int,
-) -> None:
- """Split fused QKV projection output into separate Q, K, V tensors.
-
- This is a zero-allocation operation designed for CUDA Graph compatibility.
- Output buffers must be pre-allocated.
-
- Args:
- qkv: Fused QKV tensor [seq_len, q_dim + k_dim + v_dim].
- q_out: Pre-allocated Q output buffer [seq_len, q_dim] or [seq_len, n_heads, head_dim].
- k_out: Pre-allocated K output buffer [seq_len, k_dim] or [seq_len, n_kv_heads, head_dim].
- v_out: Pre-allocated V output buffer [seq_len, v_dim] or [seq_len, n_kv_heads, head_dim].
- q_dim: Size of Q projection (num_heads * head_dim).
- k_dim: Size of K projection (num_kv_heads * head_dim).
- v_dim: Size of V projection (num_kv_heads * head_dim).
-
- Note:
- The output buffers can be 2D [seq_len, dim] or 3D [seq_len, heads, head_dim]
- as long as the total size matches. The kernel writes linearly.
- """
- from pygpukit.core.backend import get_backend, get_native_module
-
- backend = get_backend()
- if not backend.is_available():
- raise RuntimeError("split_qkv_batch requires GPU backend")
-
- native = get_native_module()
- native.split_qkv_batch(
- qkv._get_native(),
- q_out._get_native(),
- k_out._get_native(),
- v_out._get_native(),
- q_dim,
- k_dim,
- v_dim,
- )
-
-
-def slice_rows_range_ptr(
- table: GPUArray,
- out: GPUArray,
- start_pos_buf: GPUArray,
- count: int,
-) -> None:
- """Slice consecutive rows from table using GPU-stored start position.
-
- This is a zero-allocation operation designed for CUDA Graph compatibility.
- The start position is read from a GPU buffer, enabling graph replay with
- different positions without H2D copies.
-
- Args:
- table: Source table of shape [num_rows, row_dim].
- out: Pre-allocated output buffer of shape [count, row_dim].
- start_pos_buf: GPU buffer containing start position [1] int32.
- count: Number of consecutive rows to copy.
-
- Example:
- # During CUDA Graph capture
- slice_rows_range_ptr(rope_cos_table, cos_batch, start_pos_buf, batch_size)
- # Copies cos_batch[i, :] = rope_cos_table[start_pos + i, :]
- """
- from pygpukit.core.backend import get_backend, get_native_module
-
- backend = get_backend()
- if not backend.is_available():
- raise RuntimeError("slice_rows_range_ptr requires GPU backend")
-
- native = get_native_module()
- native.slice_rows_range_ptr(
- table._get_native(),
- out._get_native(),
- start_pos_buf._get_native(),
- count,
- )
-
-
-# =============================================================================
-# LSTM (Recurrent) Operations
-# =============================================================================
-
-
-def lstm_forward(
- x: GPUArray,
- W_ih: GPUArray,
- W_hh: GPUArray,
- b_ih: GPUArray,
- b_hh: GPUArray,
- h0: GPUArray | None = None,
- c0: GPUArray | None = None,
- reverse: bool = False,
-) -> tuple[GPUArray, GPUArray, GPUArray]:
- """LSTM forward pass (unidirectional).
-
- Implements the standard LSTM equations:
- i_t = sigmoid(W_ii @ x_t + b_ii + W_hi @ h_{t-1} + b_hi)
- f_t = sigmoid(W_if @ x_t + b_if + W_hf @ h_{t-1} + b_hf)
- g_t = tanh(W_ig @ x_t + b_ig + W_hg @ h_{t-1} + b_hg)
- o_t = sigmoid(W_io @ x_t + b_io + W_ho @ h_{t-1} + b_ho)
- c_t = f_t * c_{t-1} + i_t * g_t
- h_t = o_t * tanh(c_t)
-
- Args:
- x: Input sequence [batch, seq_len, input_size].
- W_ih: Input-to-hidden weights [4*hidden_size, input_size].
- W_hh: Hidden-to-hidden weights [4*hidden_size, hidden_size].
- b_ih: Input bias [4*hidden_size].
- b_hh: Hidden bias [4*hidden_size].
- h0: Initial hidden state [batch, hidden_size]. If None, zeros.
- c0: Initial cell state [batch, hidden_size]. If None, zeros.
- reverse: If True, process sequence in reverse order.
-
- Returns:
- Tuple of (output, h_n, c_n):
- output: Hidden states [batch, seq_len, hidden_size]
- h_n: Final hidden state [batch, hidden_size]
- c_n: Final cell state [batch, hidden_size]
- """
- from pygpukit.core.backend import get_backend, get_native_module
-
- backend = get_backend()
- if not backend.is_available():
- raise RuntimeError("lstm_forward requires GPU backend")
-
- native = get_native_module()
-
- # Create zero-sized arrays for None states
- if h0 is None:
- h0_native = native.GPUArray([0], native.Float32)
- else:
- h0_native = h0._get_native()
-
- if c0 is None:
- c0_native = native.GPUArray([0], native.Float32)
- else:
- c0_native = c0._get_native()
-
- output_native, h_n_native, c_n_native = native.lstm_forward(
- x._get_native(),
- W_ih._get_native(),
- W_hh._get_native(),
- b_ih._get_native(),
- b_hh._get_native(),
- h0_native,
- c0_native,
- reverse,
- )
-
- return (
- GPUArray._wrap_native(output_native),
- GPUArray._wrap_native(h_n_native),
- GPUArray._wrap_native(c_n_native),
- )
-
-
-def lstm_bidirectional(
- x: GPUArray,
- W_ih_fwd: GPUArray,
- W_hh_fwd: GPUArray,
- b_ih_fwd: GPUArray,
- b_hh_fwd: GPUArray,
- W_ih_bwd: GPUArray,
- W_hh_bwd: GPUArray,
- b_ih_bwd: GPUArray,
- b_hh_bwd: GPUArray,
-) -> tuple[GPUArray, GPUArray, GPUArray]:
- """Bidirectional LSTM.
-
- Runs forward and backward LSTM passes and concatenates the outputs.
-
- Args:
- x: Input sequence [batch, seq_len, input_size].
- W_ih_fwd, W_hh_fwd, b_ih_fwd, b_hh_fwd: Forward LSTM weights.
- W_ih_bwd, W_hh_bwd, b_ih_bwd, b_hh_bwd: Backward LSTM weights.
-
- Returns:
- Tuple of (output, h_n, c_n):
- output: Concatenated hidden states [batch, seq_len, 2*hidden_size]
- h_n: Stacked final hidden states [2, batch, hidden_size]
- c_n: Stacked final cell states [2, batch, hidden_size]
- """
- from pygpukit.core.backend import get_backend, get_native_module
-
- backend = get_backend()
- if not backend.is_available():
- raise RuntimeError("lstm_bidirectional requires GPU backend")
-
- native = get_native_module()
-
- output_native, h_n_native, c_n_native = native.lstm_bidirectional(
- x._get_native(),
- W_ih_fwd._get_native(),
- W_hh_fwd._get_native(),
- b_ih_fwd._get_native(),
- b_hh_fwd._get_native(),
- W_ih_bwd._get_native(),
- W_hh_bwd._get_native(),
- b_ih_bwd._get_native(),
- b_hh_bwd._get_native(),
- )
-
- return (
- GPUArray._wrap_native(output_native),
- GPUArray._wrap_native(h_n_native),
- GPUArray._wrap_native(c_n_native),
- )
diff --git a/src/pygpukit/ops/nn/__init__.py b/src/pygpukit/ops/nn/__init__.py
new file mode 100644
index 0000000..3e11fdc
--- /dev/null
+++ b/src/pygpukit/ops/nn/__init__.py
@@ -0,0 +1,79 @@
+"""Neural network operations for GPUArrays.
+
+Corresponds to native/ops/nn/.
+
+Provides:
+- Activation functions (gelu, silu, sigmoid, tanh)
+- Normalization layers (layernorm, rmsnorm)
+- Attention operations (sdpa_causal, sdpa_causal_fixed_cache)
+- RoPE (rotary position embedding)
+- Linear operations (bias_add_inplace, split_qkv_batch)
+- Recurrent operations (lstm_forward, lstm_bidirectional)
+"""
+
+from __future__ import annotations
+
+# Activation functions
+from pygpukit.ops.nn.activation import (
+ gelu,
+ sigmoid,
+ silu,
+ tanh,
+)
+
+# Attention operations
+from pygpukit.ops.nn.attention import (
+ sdpa_causal,
+ sdpa_causal_fixed_cache,
+ sdpa_causal_fixed_cache_ptr,
+)
+
+# Linear operations
+from pygpukit.ops.nn.linear import (
+ bias_add_inplace,
+ slice_rows_range_ptr,
+ split_qkv_batch,
+)
+
+# Normalization layers
+from pygpukit.ops.nn.norm import (
+ layernorm,
+ rmsnorm,
+)
+
+# Recurrent operations
+from pygpukit.ops.nn.recurrent import (
+ lstm_bidirectional,
+ lstm_forward,
+)
+
+# RoPE operations
+from pygpukit.ops.nn.rope import (
+ rope_inplace,
+ rope_inplace_f32table,
+)
+
+__all__ = [
+ # Activation
+ "gelu",
+ "silu",
+ "sigmoid",
+ "tanh",
+ # Normalization
+ "layernorm",
+ "rmsnorm",
+ # Attention
+ "sdpa_causal",
+ "sdpa_causal_fixed_cache",
+ "sdpa_causal_fixed_cache_ptr",
+ # RoPE
+ "rope_inplace",
+ "rope_inplace_f32table",
+ # Linear
+ "bias_add_inplace",
+ "split_qkv_batch",
+ "slice_rows_range_ptr",
+ # Recurrent
+ "lstm_forward",
+ "lstm_bidirectional",
+]
diff --git a/src/pygpukit/ops/nn/activation.py b/src/pygpukit/ops/nn/activation.py
new file mode 100644
index 0000000..266b82d
--- /dev/null
+++ b/src/pygpukit/ops/nn/activation.py
@@ -0,0 +1,177 @@
+"""Activation functions for GPUArrays.
+
+Corresponds to native/ops/nn/activation/.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.backend import NativeBackend, get_backend
+from pygpukit.core.factory import from_numpy
+from pygpukit.ops._common import _validate_float_dtype
+
+
+def gelu(a: GPUArray) -> GPUArray:
+ """GELU (Gaussian Error Linear Unit) activation.
+
+ Computes: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
+
+ Args:
+ a: Input array (float32, float64, float16, or bfloat16).
+
+ Returns:
+ A new GPUArray containing gelu(a).
+
+ Raises:
+ ValueError: If dtype is not a float type.
+ """
+ _validate_float_dtype(a, "gelu")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ return _gelu_native(a)
+ else:
+ return _gelu_cpu(a)
+
+
+def _gelu_cpu(a: GPUArray) -> GPUArray:
+ """CPU implementation of gelu."""
+ a_np = a.to_numpy()
+ # GELU approximation
+ x = a_np.astype(np.float32) if a_np.dtype in [np.float16] else a_np
+ c1 = 0.7978845608 # sqrt(2/pi)
+ c2 = 0.044715
+ result = x * 0.5 * (1 + np.tanh(c1 * (x + c2 * x**3)))
+ return from_numpy(result.astype(a_np.dtype))
+
+
+def _gelu_native(a: GPUArray) -> GPUArray:
+ """Native C++ CUDA implementation of gelu (zero-copy)."""
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+ c_native = native.gelu(a_native)
+ return GPUArray._wrap_native(c_native)
+
+
+def silu(a: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
+ """SiLU (Swish) activation: y = x * sigmoid(x).
+
+ Used in Llama and other modern LLMs as the activation in MLP layers.
+
+ Args:
+ a: Input array.
+ out: Optional pre-allocated output array. If provided, the result
+ is written to this array (for CUDA Graph capture support).
+
+ Returns:
+ A new GPUArray containing the SiLU-activated values, or the out array if provided.
+
+ Raises:
+ ValueError: If dtype is not a float type.
+ """
+ _validate_float_dtype(a, "silu")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ return _silu_native(a, out=out)
+ else:
+ return _silu_cpu(a)
+
+
+def _silu_cpu(a: GPUArray) -> GPUArray:
+ """CPU implementation of SiLU."""
+ x = a.to_numpy()
+ # SiLU = x * sigmoid(x) = x / (1 + exp(-x))
+ result = x / (1.0 + np.exp(-x))
+ return from_numpy(result)
+
+
+def _silu_native(a: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
+ """Native C++ CUDA implementation of SiLU (zero-copy)."""
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+
+ if out is not None:
+ out_native = out._get_native()
+ native.silu_(a_native, out_native)
+ return out
+ else:
+ c_native = native.silu(a_native)
+ return GPUArray._wrap_native(c_native)
+
+
+def sigmoid(a: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
+ """Sigmoid activation: y = 1 / (1 + exp(-x)).
+
+ Args:
+ a: Input array.
+ out: Optional pre-allocated output array.
+
+ Returns:
+ A new GPUArray containing the sigmoid-activated values.
+ """
+ _validate_float_dtype(a, "sigmoid")
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+
+ if out is not None:
+ out_native = out._get_native()
+ native.sigmoid_(a_native, out_native)
+ return out
+ else:
+ return GPUArray._wrap_native(native.sigmoid(a_native))
+ else:
+ x = a.to_numpy()
+ result = 1.0 / (1.0 + np.exp(-x))
+ return from_numpy(result)
+
+
+def tanh(a: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
+ """Tanh activation.
+
+ Args:
+ a: Input array.
+ out: Optional pre-allocated output array.
+
+ Returns:
+ A new GPUArray containing the tanh-activated values.
+ """
+ _validate_float_dtype(a, "tanh")
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ a_native = a._get_native()
+
+ if out is not None:
+ out_native = out._get_native()
+ native.tanh_(a_native, out_native)
+ return out
+ else:
+ return GPUArray._wrap_native(native.tanh(a_native))
+ else:
+ x = a.to_numpy()
+ return from_numpy(np.tanh(x))
+
+
+__all__ = [
+ "gelu",
+ "silu",
+ "sigmoid",
+ "tanh",
+]
diff --git a/src/pygpukit/ops/nn/attention.py b/src/pygpukit/ops/nn/attention.py
new file mode 100644
index 0000000..aa8e556
--- /dev/null
+++ b/src/pygpukit/ops/nn/attention.py
@@ -0,0 +1,242 @@
+"""Attention operations for GPUArrays.
+
+Corresponds to native/ops/nn/attention/.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.backend import NativeBackend, get_backend
+from pygpukit.core.factory import from_numpy
+from pygpukit.ops._common import _validate_float_dtype
+
+
+def sdpa_causal(
+ Q: GPUArray,
+ K: GPUArray,
+ V: GPUArray,
+ scale: float = 0.0,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """Scaled Dot-Product Attention with causal mask.
+
+ Computes attention with automatic causal masking for autoregressive
+ sequence generation. This is the core attention operation used in
+ transformer models.
+
+ Algorithm:
+ scores = Q @ K^T / scale
+ scores = apply_causal_mask(scores)
+ weights = softmax(scores)
+ output = weights @ V
+
+ Args:
+ Q: Query tensor of shape [n_heads, q_len, head_dim].
+ K: Key tensor of shape [n_heads, kv_len, head_dim].
+ V: Value tensor of shape [n_heads, kv_len, head_dim].
+ scale: Scaling factor (typically 1/sqrt(head_dim)).
+ If <= 0, computed automatically from head_dim.
+ out: Optional output buffer [n_heads, q_len, head_dim].
+ If provided, result is written in-place (for CUDA Graph capture).
+
+ Returns:
+ Output tensor of shape [n_heads, q_len, head_dim].
+
+ Raises:
+ ValueError: If shapes or dtypes don't match.
+
+ Note:
+ For KV cache usage during inference, kv_len >= q_len.
+ The causal mask ensures query at position i can only attend
+ to key positions 0 to (kv_len - q_len + i).
+ """
+ _validate_float_dtype(Q, "sdpa_causal")
+
+ if Q.ndim != 3 or K.ndim != 3 or V.ndim != 3:
+ raise ValueError("sdpa_causal expects 3D inputs [n_heads, seq_len, head_dim]")
+ if Q.dtype != K.dtype or Q.dtype != V.dtype:
+ raise ValueError("sdpa_causal: Q, K, V must have same dtype")
+
+ n_heads, q_len, head_dim = Q.shape
+
+ if K.shape[0] != n_heads or V.shape[0] != n_heads:
+ raise ValueError("sdpa_causal: n_heads mismatch")
+ if K.shape[2] != head_dim or V.shape[2] != head_dim:
+ raise ValueError("sdpa_causal: head_dim mismatch")
+ if K.shape[1] != V.shape[1]:
+ raise ValueError("sdpa_causal: K and V seq_len mismatch")
+
+ # Validate out array if provided
+ if out is not None:
+ if out.shape != (n_heads, q_len, head_dim):
+ raise ValueError(
+ f"out shape {out.shape} does not match expected {(n_heads, q_len, head_dim)}"
+ )
+ if out.dtype != Q.dtype:
+ raise ValueError(f"out dtype {out.dtype} does not match Q dtype {Q.dtype}")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ return _sdpa_causal_native(Q, K, V, scale, out=out)
+ else:
+ return _sdpa_causal_cpu(Q, K, V, scale, out=out)
+
+
+def _sdpa_causal_cpu(
+ Q: GPUArray,
+ K: GPUArray,
+ V: GPUArray,
+ scale: float,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """CPU implementation of SDPA with causal mask."""
+ q = Q.to_numpy()
+ k = K.to_numpy()
+ v = V.to_numpy()
+
+ n_heads, q_len, head_dim = q.shape
+ kv_len = k.shape[1]
+
+ if scale <= 0:
+ scale = 1.0 / np.sqrt(head_dim)
+
+ # scores: [n_heads, q_len, kv_len]
+ scores = np.matmul(q, k.transpose(0, 2, 1)) * scale
+
+ # Create causal mask
+ causal_offset = kv_len - q_len
+ for i in range(q_len):
+ max_attend = causal_offset + i + 1
+ if max_attend < kv_len:
+ scores[:, i, max_attend:] = -np.inf
+
+ # Softmax over last dimension
+ scores_max = scores.max(axis=-1, keepdims=True)
+ exp_scores = np.exp(scores - scores_max)
+ weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)
+
+ # output: [n_heads, q_len, head_dim]
+ output = np.matmul(weights, v)
+
+ if out is not None:
+ out_np = out.to_numpy()
+ np.copyto(out_np, output.astype(q.dtype))
+ out._data = from_numpy(out_np)._data
+ return out
+ return from_numpy(output.astype(q.dtype))
+
+
+def _sdpa_causal_native(
+ Q: GPUArray,
+ K: GPUArray,
+ V: GPUArray,
+ scale: float,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """Native C++ CUDA implementation of SDPA with causal mask."""
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ q_native = Q._get_native()
+ k_native = K._get_native()
+ v_native = V._get_native()
+
+ if out is not None:
+ out_native = out._get_native()
+ native.sdpa_causal_(q_native, k_native, v_native, out_native, scale)
+ return out
+ else:
+ c_native = native.sdpa_causal(q_native, k_native, v_native, scale)
+ return GPUArray._wrap_native(c_native)
+
+
+def sdpa_causal_fixed_cache(
+ Q: GPUArray,
+ K: GPUArray,
+ V: GPUArray,
+ out: GPUArray,
+ context_len: int,
+ scale: float = 0.0,
+) -> None:
+ """SDPA with fixed-length KV cache for CUDA Graph capture.
+
+ This variant is designed for use with pre-allocated KV caches where
+ the buffer size (max_seq_len) is larger than the actual context length.
+
+ Args:
+ Q: Query tensor of shape [n_heads, q_len, head_dim].
+ K: Key cache of shape [n_heads, max_seq_len, head_dim].
+ V: Value cache of shape [n_heads, max_seq_len, head_dim].
+ out: Pre-allocated output buffer [n_heads, q_len, head_dim].
+ context_len: Actual number of valid tokens in KV cache.
+ scale: Scaling factor (typically 1/sqrt(head_dim)).
+ If <= 0, computed automatically from head_dim.
+
+ Raises:
+ ValueError: If shapes or dtypes don't match, or context_len is invalid.
+ """
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ q_native = Q._get_native()
+ k_native = K._get_native()
+ v_native = V._get_native()
+ out_native = out._get_native()
+
+ native.sdpa_causal_fixed_cache(q_native, k_native, v_native, out_native, context_len, scale)
+
+
+def sdpa_causal_fixed_cache_ptr(
+ Q: GPUArray,
+ K: GPUArray,
+ V: GPUArray,
+ out: GPUArray,
+ context_len_buf: GPUArray,
+ max_kv_len: int,
+ scale: float = 0.0,
+) -> None:
+ """SDPA with pointer-based context_len for CUDA Graph replay.
+
+ This variant reads context_len from a GPU buffer at runtime, enabling
+ CUDA Graph replay with dynamic context lengths without re-capture.
+
+ Args:
+ Q: Query tensor of shape [n_heads, q_len, head_dim].
+ K: Key cache of shape [n_heads, max_seq_len, head_dim].
+ V: Value cache of shape [n_heads, max_seq_len, head_dim].
+ out: Pre-allocated output buffer [n_heads, q_len, head_dim].
+ context_len_buf: GPU int32 buffer containing actual context_len [1].
+ max_kv_len: Maximum context length (for shared memory allocation
+ during graph capture). Must be <= K.shape[1].
+ scale: Scaling factor (typically 1/sqrt(head_dim)).
+ If <= 0, computed automatically from head_dim.
+
+ Note:
+ For CUDA Graph: capture with max_kv_len, then update context_len_buf
+ before each replay to change the effective context length.
+ """
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ q_native = Q._get_native()
+ k_native = K._get_native()
+ v_native = V._get_native()
+ out_native = out._get_native()
+ ctx_buf_native = context_len_buf._get_native()
+
+ native.sdpa_causal_fixed_cache_ptr(
+ q_native, k_native, v_native, out_native, ctx_buf_native, max_kv_len, scale
+ )
+
+
+__all__ = [
+ "sdpa_causal",
+ "sdpa_causal_fixed_cache",
+ "sdpa_causal_fixed_cache_ptr",
+]
diff --git a/src/pygpukit/ops/nn/linear.py b/src/pygpukit/ops/nn/linear.py
new file mode 100644
index 0000000..23f337e
--- /dev/null
+++ b/src/pygpukit/ops/nn/linear.py
@@ -0,0 +1,159 @@
+"""Linear layer operations for GPUArrays.
+
+Corresponds to native/ops/nn/linear/ and native/ops/nn/tensor/.
+"""
+
+from __future__ import annotations
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.backend import NativeBackend, get_backend
+from pygpukit.core.factory import from_numpy
+from pygpukit.ops._common import _validate_float_dtype
+
+
+def bias_add_inplace(output: GPUArray, bias: GPUArray) -> None:
+ """Add bias to output in-place.
+
+ Computes: output[batch, features] += bias[features]
+
+ Args:
+ output: Output array of shape [batch, features] (modified in-place).
+ bias: Bias array of shape [features].
+
+ Raises:
+ ValueError: If shapes don't match or dtypes don't match.
+ """
+ _validate_float_dtype(output, "bias_add_inplace")
+
+ if output.ndim != 2:
+ raise ValueError(
+ f"bias_add_inplace expects 2D output [batch, features], got {output.ndim}D"
+ )
+ if bias.ndim != 1:
+ raise ValueError(f"bias_add_inplace expects 1D bias [features], got {bias.ndim}D")
+ if output.dtype != bias.dtype:
+ raise ValueError("bias_add_inplace: output and bias must have same dtype")
+
+ features = output.shape[1]
+ if bias.shape[0] != features:
+ raise ValueError(
+ f"bias_add_inplace: bias size {bias.shape[0]} must match features {features}"
+ )
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ _bias_add_inplace_native(output, bias)
+ else:
+ _bias_add_inplace_cpu(output, bias)
+
+
+def _bias_add_inplace_cpu(output: GPUArray, bias: GPUArray) -> None:
+ """CPU implementation of bias_add_inplace."""
+ # For CPU backend, we need to get numpy arrays, modify, and update
+ output_np = output.to_numpy()
+ bias_np = bias.to_numpy()
+ output_np += bias_np
+ # Note: This creates a new array - for CPU backend, in-place is not truly in-place
+ # The native backend does true in-place modification
+ output._data = from_numpy(output_np)._data
+
+
+def _bias_add_inplace_native(output: GPUArray, bias: GPUArray) -> None:
+ """Native C++ CUDA implementation of bias_add_inplace (true in-place)."""
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ output_native = output._get_native()
+ bias_native = bias._get_native()
+ native.bias_add_inplace(output_native, bias_native)
+
+
+def split_qkv_batch(
+ qkv: GPUArray,
+ q_out: GPUArray,
+ k_out: GPUArray,
+ v_out: GPUArray,
+ q_dim: int,
+ k_dim: int,
+ v_dim: int,
+) -> None:
+ """Split fused QKV projection output into separate Q, K, V tensors.
+
+ This is a zero-allocation operation designed for CUDA Graph compatibility.
+ Output buffers must be pre-allocated.
+
+ Args:
+ qkv: Fused QKV tensor [seq_len, q_dim + k_dim + v_dim].
+ q_out: Pre-allocated Q output buffer [seq_len, q_dim] or [seq_len, n_heads, head_dim].
+ k_out: Pre-allocated K output buffer [seq_len, k_dim] or [seq_len, n_kv_heads, head_dim].
+ v_out: Pre-allocated V output buffer [seq_len, v_dim] or [seq_len, n_kv_heads, head_dim].
+ q_dim: Size of Q projection (num_heads * head_dim).
+ k_dim: Size of K projection (num_kv_heads * head_dim).
+ v_dim: Size of V projection (num_kv_heads * head_dim).
+
+ Note:
+ The output buffers can be 2D [seq_len, dim] or 3D [seq_len, heads, head_dim]
+ as long as the total size matches. The kernel writes linearly.
+ """
+ from pygpukit.core.backend import get_backend, get_native_module
+
+ backend = get_backend()
+ if not backend.is_available():
+ raise RuntimeError("split_qkv_batch requires GPU backend")
+
+ native = get_native_module()
+ native.split_qkv_batch(
+ qkv._get_native(),
+ q_out._get_native(),
+ k_out._get_native(),
+ v_out._get_native(),
+ q_dim,
+ k_dim,
+ v_dim,
+ )
+
+
+def slice_rows_range_ptr(
+ table: GPUArray,
+ out: GPUArray,
+ start_pos_buf: GPUArray,
+ count: int,
+) -> None:
+ """Slice consecutive rows from table using GPU-stored start position.
+
+ This is a zero-allocation operation designed for CUDA Graph compatibility.
+ The start position is read from a GPU buffer, enabling graph replay with
+ different positions without H2D copies.
+
+ Args:
+ table: Source table of shape [num_rows, row_dim].
+ out: Pre-allocated output buffer of shape [count, row_dim].
+ start_pos_buf: GPU buffer containing start position [1] int32.
+ count: Number of consecutive rows to copy.
+
+ Example:
+ # During CUDA Graph capture
+ slice_rows_range_ptr(rope_cos_table, cos_batch, start_pos_buf, batch_size)
+ # Copies cos_batch[i, :] = rope_cos_table[start_pos + i, :]
+ """
+ from pygpukit.core.backend import get_backend, get_native_module
+
+ backend = get_backend()
+ if not backend.is_available():
+ raise RuntimeError("slice_rows_range_ptr requires GPU backend")
+
+ native = get_native_module()
+ native.slice_rows_range_ptr(
+ table._get_native(),
+ out._get_native(),
+ start_pos_buf._get_native(),
+ count,
+ )
+
+
+__all__ = [
+ "bias_add_inplace",
+ "split_qkv_batch",
+ "slice_rows_range_ptr",
+]
diff --git a/src/pygpukit/ops/nn/norm.py b/src/pygpukit/ops/nn/norm.py
new file mode 100644
index 0000000..121a1aa
--- /dev/null
+++ b/src/pygpukit/ops/nn/norm.py
@@ -0,0 +1,224 @@
+"""Normalization layers for GPUArrays.
+
+Corresponds to native/ops/nn/norm/.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.backend import NativeBackend, get_backend
+from pygpukit.core.factory import from_numpy
+from pygpukit.ops._common import _validate_float_dtype
+
+
+def layernorm(
+ input: GPUArray,
+ gamma: GPUArray,
+ beta: GPUArray,
+ eps: float = 1e-5,
+) -> GPUArray:
+ """Layer normalization.
+
+ Computes: (x - mean) / sqrt(var + eps) * gamma + beta
+
+ Args:
+ input: Input array of shape [batch, features] or [batch, seq_len, features].
+ gamma: Scale parameter of shape [features].
+ beta: Bias parameter of shape [features].
+ eps: Small epsilon for numerical stability.
+
+ Returns:
+ A new GPUArray containing the normalized output.
+
+ Raises:
+ ValueError: If shapes or dtypes don't match.
+ """
+ _validate_float_dtype(input, "layernorm")
+
+ if input.ndim not in (2, 3):
+ raise ValueError(f"layernorm expects 2D or 3D input, got {input.ndim}D")
+ if gamma.ndim != 1 or beta.ndim != 1:
+ raise ValueError("layernorm expects 1D gamma and beta")
+ if input.dtype != gamma.dtype or input.dtype != beta.dtype:
+ raise ValueError("layernorm: all inputs must have same dtype")
+
+ features = input.shape[-1] # Last dimension is features
+ if gamma.shape[0] != features or beta.shape[0] != features:
+ raise ValueError(
+ f"layernorm: gamma/beta size {gamma.shape[0]} must match features {features}"
+ )
+
+ # Handle 3D input by reshaping to 2D, processing, and reshaping back
+ if input.ndim == 3:
+ batch, seq_len, feat = input.shape
+ input_2d = input.reshape(batch * seq_len, feat)
+ result_2d = _layernorm_dispatch(input_2d, gamma, beta, eps)
+ return result_2d.reshape(batch, seq_len, feat)
+ else:
+ return _layernorm_dispatch(input, gamma, beta, eps)
+
+
+def _layernorm_dispatch(
+ input: GPUArray,
+ gamma: GPUArray,
+ beta: GPUArray,
+ eps: float,
+) -> GPUArray:
+ """Dispatch layernorm to native or CPU implementation."""
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ return _layernorm_native(input, gamma, beta, eps)
+ else:
+ return _layernorm_cpu(input, gamma, beta, eps)
+
+
+def _layernorm_cpu(
+ input: GPUArray,
+ gamma: GPUArray,
+ beta: GPUArray,
+ eps: float,
+) -> GPUArray:
+ """CPU implementation of layernorm."""
+ x = input.to_numpy()
+ g = gamma.to_numpy()
+ b = beta.to_numpy()
+
+ # Compute mean and variance along features axis
+ mean = x.mean(axis=1, keepdims=True)
+ var = x.var(axis=1, keepdims=True)
+
+ # Normalize
+ normalized = (x - mean) / np.sqrt(var + eps)
+
+ # Apply affine transform
+ result = normalized * g + b
+ return from_numpy(result)
+
+
+def _layernorm_native(
+ input: GPUArray,
+ gamma: GPUArray,
+ beta: GPUArray,
+ eps: float,
+) -> GPUArray:
+ """Native C++ CUDA implementation of layernorm (zero-copy)."""
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ input_native = input._get_native()
+ gamma_native = gamma._get_native()
+ beta_native = beta._get_native()
+ c_native = native.layernorm(input_native, gamma_native, beta_native, eps)
+ return GPUArray._wrap_native(c_native)
+
+
+def rmsnorm(
+ input: GPUArray,
+ gamma: GPUArray,
+ eps: float = 1e-5,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """RMS Normalization (Root Mean Square Normalization).
+
+ Computes: x / sqrt(mean(x^2) + eps) * gamma
+
+ Simpler than LayerNorm (no mean subtraction, no beta).
+ Used in Llama and other modern LLMs.
+
+ Args:
+ input: Input array of shape [batch, features].
+ gamma: Scale parameter of shape [features].
+ eps: Small epsilon for numerical stability.
+ out: Optional output buffer. If provided, result is written in-place
+ (for CUDA Graph capture).
+
+ Returns:
+ A new GPUArray containing the normalized output (or out if provided).
+
+ Raises:
+ ValueError: If shapes or dtypes don't match.
+ """
+ _validate_float_dtype(input, "rmsnorm")
+
+ if input.ndim != 2:
+ raise ValueError(f"rmsnorm expects 2D input [batch, features], got {input.ndim}D")
+ if gamma.ndim != 1:
+ raise ValueError("rmsnorm expects 1D gamma")
+ if input.dtype != gamma.dtype:
+ raise ValueError("rmsnorm: all inputs must have same dtype")
+
+ features = input.shape[1]
+ if gamma.shape[0] != features:
+ raise ValueError(f"rmsnorm: gamma size {gamma.shape[0]} must match features {features}")
+
+ # Validate out array if provided
+ if out is not None:
+ if out.shape != input.shape:
+ raise ValueError(f"out shape {out.shape} does not match input shape {input.shape}")
+ if out.dtype != input.dtype:
+ raise ValueError(f"out dtype {out.dtype} does not match input dtype {input.dtype}")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ return _rmsnorm_native(input, gamma, eps, out=out)
+ else:
+ return _rmsnorm_cpu(input, gamma, eps, out=out)
+
+
+def _rmsnorm_cpu(
+ input: GPUArray,
+ gamma: GPUArray,
+ eps: float,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """CPU implementation of rmsnorm."""
+ x = input.to_numpy()
+ g = gamma.to_numpy()
+
+ # RMS = sqrt(mean(x^2) + eps)
+ rms = np.sqrt(np.mean(x**2, axis=1, keepdims=True) + eps)
+
+ # Normalize and scale
+ result = (x / rms) * g
+
+ if out is not None:
+ out_np = out.to_numpy()
+ np.copyto(out_np, result)
+ out._data = from_numpy(out_np)._data
+ return out
+ return from_numpy(result)
+
+
+def _rmsnorm_native(
+ input: GPUArray,
+ gamma: GPUArray,
+ eps: float,
+ *,
+ out: GPUArray | None = None,
+) -> GPUArray:
+ """Native C++ CUDA implementation of rmsnorm (zero-copy)."""
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ input_native = input._get_native()
+ gamma_native = gamma._get_native()
+
+ if out is not None:
+ out_native = out._get_native()
+ native.rmsnorm_(input_native, gamma_native, out_native, eps)
+ return out
+ else:
+ c_native = native.rmsnorm(input_native, gamma_native, eps)
+ return GPUArray._wrap_native(c_native)
+
+
+__all__ = [
+ "layernorm",
+ "rmsnorm",
+]
diff --git a/src/pygpukit/ops/nn/recurrent.py b/src/pygpukit/ops/nn/recurrent.py
new file mode 100644
index 0000000..f8ddeb4
--- /dev/null
+++ b/src/pygpukit/ops/nn/recurrent.py
@@ -0,0 +1,140 @@
+"""Recurrent (LSTM) operations for GPUArrays.
+
+Corresponds to native/ops/nn/recurrent/.
+"""
+
+from __future__ import annotations
+
+from pygpukit.core.array import GPUArray
+
+
+def lstm_forward(
+ x: GPUArray,
+ W_ih: GPUArray,
+ W_hh: GPUArray,
+ b_ih: GPUArray,
+ b_hh: GPUArray,
+ h0: GPUArray | None = None,
+ c0: GPUArray | None = None,
+ reverse: bool = False,
+) -> tuple[GPUArray, GPUArray, GPUArray]:
+ """LSTM forward pass (unidirectional).
+
+ Implements the standard LSTM equations:
+ i_t = sigmoid(W_ii @ x_t + b_ii + W_hi @ h_{t-1} + b_hi)
+ f_t = sigmoid(W_if @ x_t + b_if + W_hf @ h_{t-1} + b_hf)
+ g_t = tanh(W_ig @ x_t + b_ig + W_hg @ h_{t-1} + b_hg)
+ o_t = sigmoid(W_io @ x_t + b_io + W_ho @ h_{t-1} + b_ho)
+ c_t = f_t * c_{t-1} + i_t * g_t
+ h_t = o_t * tanh(c_t)
+
+ Args:
+ x: Input sequence [batch, seq_len, input_size].
+ W_ih: Input-to-hidden weights [4*hidden_size, input_size].
+ W_hh: Hidden-to-hidden weights [4*hidden_size, hidden_size].
+ b_ih: Input bias [4*hidden_size].
+ b_hh: Hidden bias [4*hidden_size].
+ h0: Initial hidden state [batch, hidden_size]. If None, zeros.
+ c0: Initial cell state [batch, hidden_size]. If None, zeros.
+ reverse: If True, process sequence in reverse order.
+
+ Returns:
+ Tuple of (output, h_n, c_n):
+ output: Hidden states [batch, seq_len, hidden_size]
+ h_n: Final hidden state [batch, hidden_size]
+ c_n: Final cell state [batch, hidden_size]
+ """
+ from pygpukit.core.backend import get_backend, get_native_module
+
+ backend = get_backend()
+ if not backend.is_available():
+ raise RuntimeError("lstm_forward requires GPU backend")
+
+ native = get_native_module()
+
+ # Create zero-sized arrays for None states
+ if h0 is None:
+ h0_native = native.GPUArray([0], native.Float32)
+ else:
+ h0_native = h0._get_native()
+
+ if c0 is None:
+ c0_native = native.GPUArray([0], native.Float32)
+ else:
+ c0_native = c0._get_native()
+
+ output_native, h_n_native, c_n_native = native.lstm_forward(
+ x._get_native(),
+ W_ih._get_native(),
+ W_hh._get_native(),
+ b_ih._get_native(),
+ b_hh._get_native(),
+ h0_native,
+ c0_native,
+ reverse,
+ )
+
+ return (
+ GPUArray._wrap_native(output_native),
+ GPUArray._wrap_native(h_n_native),
+ GPUArray._wrap_native(c_n_native),
+ )
+
+
+def lstm_bidirectional(
+ x: GPUArray,
+ W_ih_fwd: GPUArray,
+ W_hh_fwd: GPUArray,
+ b_ih_fwd: GPUArray,
+ b_hh_fwd: GPUArray,
+ W_ih_bwd: GPUArray,
+ W_hh_bwd: GPUArray,
+ b_ih_bwd: GPUArray,
+ b_hh_bwd: GPUArray,
+) -> tuple[GPUArray, GPUArray, GPUArray]:
+ """Bidirectional LSTM.
+
+ Runs forward and backward LSTM passes and concatenates the outputs.
+
+ Args:
+ x: Input sequence [batch, seq_len, input_size].
+ W_ih_fwd, W_hh_fwd, b_ih_fwd, b_hh_fwd: Forward LSTM weights.
+ W_ih_bwd, W_hh_bwd, b_ih_bwd, b_hh_bwd: Backward LSTM weights.
+
+ Returns:
+ Tuple of (output, h_n, c_n):
+ output: Concatenated hidden states [batch, seq_len, 2*hidden_size]
+ h_n: Stacked final hidden states [2, batch, hidden_size]
+ c_n: Stacked final cell states [2, batch, hidden_size]
+ """
+ from pygpukit.core.backend import get_backend, get_native_module
+
+ backend = get_backend()
+ if not backend.is_available():
+ raise RuntimeError("lstm_bidirectional requires GPU backend")
+
+ native = get_native_module()
+
+ output_native, h_n_native, c_n_native = native.lstm_bidirectional(
+ x._get_native(),
+ W_ih_fwd._get_native(),
+ W_hh_fwd._get_native(),
+ b_ih_fwd._get_native(),
+ b_hh_fwd._get_native(),
+ W_ih_bwd._get_native(),
+ W_hh_bwd._get_native(),
+ b_ih_bwd._get_native(),
+ b_hh_bwd._get_native(),
+ )
+
+ return (
+ GPUArray._wrap_native(output_native),
+ GPUArray._wrap_native(h_n_native),
+ GPUArray._wrap_native(c_n_native),
+ )
+
+
+__all__ = [
+ "lstm_forward",
+ "lstm_bidirectional",
+]
diff --git a/src/pygpukit/ops/nn/rope.py b/src/pygpukit/ops/nn/rope.py
new file mode 100644
index 0000000..0c81a2f
--- /dev/null
+++ b/src/pygpukit/ops/nn/rope.py
@@ -0,0 +1,136 @@
+"""RoPE (Rotary Position Embedding) operations for GPUArrays.
+
+Corresponds to native/ops/nn/rope/.
+"""
+
+from __future__ import annotations
+
+from pygpukit.core.array import GPUArray
+from pygpukit.core.backend import NativeBackend, get_backend
+from pygpukit.core.factory import from_numpy
+from pygpukit.ops._common import _validate_float_dtype
+
+
+def rope_inplace(
+ q: GPUArray,
+ k: GPUArray,
+ cos: GPUArray,
+ sin: GPUArray,
+) -> None:
+ """Apply Rotary Position Embedding (RoPE) to Q and K tensors in-place.
+
+ Args:
+ q: Query tensor of shape [seq_len, n_heads_q, head_dim] (modified in-place).
+ k: Key tensor of shape [seq_len, n_heads_k, head_dim] (modified in-place).
+ cos: Precomputed cosine of shape [seq_len, head_dim].
+ sin: Precomputed sine of shape [seq_len, head_dim].
+
+ Note:
+ This operation modifies q and k in-place.
+ Works with GQA (n_heads_k can be different from n_heads_q).
+ """
+ _validate_float_dtype(q, "rope_inplace")
+
+ if q.ndim != 3 or k.ndim != 3:
+ raise ValueError("rope_inplace expects 3D q, k [seq_len, n_heads, head_dim]")
+ if cos.ndim != 2 or sin.ndim != 2:
+ raise ValueError("rope_inplace expects 2D cos, sin [seq_len, head_dim]")
+
+ backend = get_backend()
+
+ if isinstance(backend, NativeBackend) and backend.is_available():
+ _rope_inplace_native(q, k, cos, sin)
+ else:
+ _rope_inplace_cpu(q, k, cos, sin)
+
+
+def _rope_inplace_cpu(
+ q: GPUArray,
+ k: GPUArray,
+ cos: GPUArray,
+ sin: GPUArray,
+) -> None:
+ """CPU implementation of rope_inplace."""
+
+ q_np = q.to_numpy()
+ k_np = k.to_numpy()
+ cos_np = cos.to_numpy()
+ sin_np = sin.to_numpy()
+
+ seq_len, n_heads_q, head_dim = q_np.shape
+ n_heads_k = k_np.shape[1]
+ half_dim = head_dim // 2
+
+ # Apply RoPE to Q
+ for s in range(seq_len):
+ c = cos_np[s, :half_dim]
+ sn = sin_np[s, :half_dim]
+ for h in range(n_heads_q):
+ q0 = q_np[s, h, :half_dim].copy()
+ q1 = q_np[s, h, half_dim:].copy()
+ q_np[s, h, :half_dim] = q0 * c - q1 * sn
+ q_np[s, h, half_dim:] = q1 * c + q0 * sn
+
+ # Apply RoPE to K
+ for s in range(seq_len):
+ c = cos_np[s, :half_dim]
+ sn = sin_np[s, :half_dim]
+ for h in range(n_heads_k):
+ k0 = k_np[s, h, :half_dim].copy()
+ k1 = k_np[s, h, half_dim:].copy()
+ k_np[s, h, :half_dim] = k0 * c - k1 * sn
+ k_np[s, h, half_dim:] = k1 * c + k0 * sn
+
+ # Update the GPUArray data in-place
+ q._data = from_numpy(q_np)._data
+ k._data = from_numpy(k_np)._data
+
+
+def _rope_inplace_native(
+ q: GPUArray,
+ k: GPUArray,
+ cos: GPUArray,
+ sin: GPUArray,
+) -> None:
+ """Native C++ CUDA implementation of rope_inplace."""
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ q_native = q._get_native()
+ k_native = k._get_native()
+ cos_native = cos._get_native()
+ sin_native = sin._get_native()
+ native.rope_inplace(q_native, k_native, cos_native, sin_native)
+
+
+def rope_inplace_f32table(
+ q: GPUArray,
+ k: GPUArray,
+ cos: GPUArray,
+ sin: GPUArray,
+) -> None:
+ """Apply RoPE with FP32 cos/sin tables (higher precision for bf16/f16).
+
+ Uses FP32 cos/sin tables for higher precision computation, avoiding
+ the need to convert tables to bf16/f16.
+
+ Args:
+ q: Query tensor [seq_len, n_heads_q, head_dim] (bf16 or f16, modified in-place).
+ k: Key tensor [seq_len, n_heads_k, head_dim] (bf16 or f16, modified in-place).
+ cos: Precomputed cosine [seq_len, head_dim] (f32).
+ sin: Precomputed sine [seq_len, head_dim] (f32).
+ """
+ from pygpukit.core.backend import get_native_module
+
+ native = get_native_module()
+ q_native = q._get_native()
+ k_native = k._get_native()
+ cos_native = cos._get_native()
+ sin_native = sin._get_native()
+ native.rope_inplace_f32table(q_native, k_native, cos_native, sin_native)
+
+
+__all__ = [
+ "rope_inplace",
+ "rope_inplace_f32table",
+]