diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 92c18c5..5ca19d9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,6 +69,9 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 1 # Shallow clone for faster checkout - name: Set up Python uses: actions/setup-python@v5 diff --git a/.github/workflows/publish-testpypi.yml b/.github/workflows/publish-testpypi.yml index 4e00809..959badc 100644 --- a/.github/workflows/publish-testpypi.yml +++ b/.github/workflows/publish-testpypi.yml @@ -19,6 +19,9 @@ jobs: steps: - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 1 # Shallow clone for faster checkout - name: Set up Python uses: actions/setup-python@v5 @@ -60,6 +63,9 @@ jobs: steps: - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 1 # Shallow clone for faster checkout - name: Set up Python 3.12 uses: actions/setup-python@v5 @@ -136,6 +142,9 @@ jobs: steps: - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 1 # Shallow clone for faster checkout - name: Set up Python 3.12 shell: pwsh diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index cc526ff..ea7c4ad 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -18,6 +18,9 @@ jobs: steps: - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 1 # Shallow clone for faster checkout - name: Set up Python uses: actions/setup-python@v5 @@ -44,6 +47,9 @@ jobs: steps: - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 1 # Shallow clone for faster checkout - name: Set up Python 3.12 uses: actions/setup-python@v5 @@ -124,6 +130,9 @@ jobs: steps: - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 1 # Shallow clone for faster checkout - name: Set up Python 3.12 shell: pwsh @@ -289,6 +298,9 @@ jobs: steps: - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 1 # Shallow clone for faster checkout - name: Download all artifacts uses: actions/download-artifact@v4 diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..281cb2d --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/cutlass"] + path = third_party/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/README.md b/README.md index 556d202..157cdf9 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,89 @@ PyGPUkit aims to be the "micro-runtime for GPU computing": small, fast, and idea --- +## What's New in v0.2.6 + +### CUTLASS Backend (Default) +NVIDIA CUTLASS v4.3.0 is now the default GEMM backend, delivering optimized TensorCore performance out of the box. + +| Feature | Description | +|---------|-------------| +| **TF32 TensorCore** | 31+ TFLOPS for FP32 inputs (automatic) | +| **FP16 TensorCore** | 63 TFLOPS | +| **BF16 TensorCore** | 63 TFLOPS | +| **Zero Config** | No environment variables needed | + +```python +import pygpukit as gpk +import numpy as np + +# CUTLASS TF32 is automatic for FP32 (31+ TFLOPS) +a = gpk.from_numpy(np.random.randn(8192, 8192).astype(np.float32)) +b = gpk.from_numpy(np.random.randn(8192, 8192).astype(np.float32)) +c = a @ b # Uses CUTLASS TF32 TensorCore + +# For full FP32 precision (no TF32), set: +# PYGPUKIT_NO_TF32=1 +``` + +### Multi-LLM Concurrent Execution +Run multiple AI models (LLM, TTS, Vision) concurrently on a single GPU with independent CUDA streams and VRAM budgets. 
+ +| Feature | Description | +|---------|-------------| +| **Execution Control** | User controls execution order | +| **Stream Isolation** | No implicit sync between streams | +| **VRAM Budgeting** | Safe memory sharing per model | +| **Concurrent Safety** | "Running simultaneously doesn't break" | +| **asyncio Integration** | Native Python async/await support | + +> **Note:** On a single GPU, Multi-LLM scheduling enables **concurrent execution, not faster execution**, for compute-bound workloads. Speedup benefits apply to I/O-bound workloads or multi-GPU setups. + +```python +import asyncio +from pygpukit.scheduler import ( + create_context, context_session, GB, initialize +) + +# Create execution contexts with VRAM budgets +initialize(device_id=0) +llm_ctx = create_context("llm", max_vram=4 * GB) +tts_ctx = create_context("tts", max_vram=2 * GB) + +async def run_parallel(): + async with context_session(llm_ctx), context_session(tts_ctx): + # Run models concurrently with asyncio.gather + llm_task = asyncio.create_task(run_llm_inference()) + tts_task = asyncio.create_task(run_tts_synthesis()) + + text, audio = await asyncio.gather(llm_task, tts_task) + return text, audio + +result = asyncio.run(run_parallel()) +``` + +### FP16/BF16 TensorCore (via CUTLASS) +| Feature | Description | +|---------|-------------| +| **FP16 TensorCore** | 63 TFLOPS (automatic via CUTLASS) | +| **BF16 TensorCore** | 63 TFLOPS (automatic via CUTLASS) | +| **FP32 Accumulation** | Numerical stability maintained | + +```python +import pygpukit as gpk +import numpy as np + +# FP16 TensorCore matmul (63 TFLOPS on RTX 3090 Ti) +# No environment variable needed - CUTLASS is automatic +a = gpk.from_numpy(np.random.randn(8192, 8192).astype(np.float16)) +b = gpk.from_numpy(np.random.randn(8192, 8192).astype(np.float16)) +c = a @ b # Uses CUTLASS TensorCore +``` + +> **Note:** CUTLASS requires matrix dimensions divisible by 16. + +--- + ## What's New in v0.2.5 ### FP16 / BF16 Support @@ -99,23 +182,23 @@ print(f"NVRTC Path: {gp.get_nvrtc_path()}") # Path to NVRTC DLL (if available) ### Benchmark Comparison (RTX 3090 Ti, 8192×8192) -| Library | FP32 | TF32 | Requirements | -|---------|------|------|--------------| -| **NumPy** (OpenBLAS) | ~0.8 TFLOPS | — | CPU only | -| **cuBLAS** | ~21 TFLOPS | ~59 TFLOPS | CUDA Toolkit | -| **PyGPUkit** | 16.7 TFLOPS | 29.7 TFLOPS | GPU drivers only | +| Library | FP32 | TF32 | FP16 | BF16 | Requirements | +|---------|------|------|------|------|--------------| +| **NumPy** (OpenBLAS) | ~0.8 TFLOPS | — | — | — | CPU only | +| **cuBLAS** | ~21 TFLOPS | ~59 TFLOPS | ~75 TFLOPS | ~83 TFLOPS | CUDA Toolkit | +| **PyGPUkit** (CUTLASS) | 18 TFLOPS | **31 TFLOPS** | **63 TFLOPS** | **63 TFLOPS** | GPU drivers only | > Built-in matmul kernels are pre-compiled. Driver-Only and Full (JIT) modes have identical matmul performance. JIT is only needed for custom kernels. 
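The throughput figures above can be sanity-checked with a short timing script. The sketch below is illustrative and not part of the repository: it assumes `a @ b` blocks until the kernel finishes (the native dispatch in this diff calls `sync_and_check` after every matmul, so wall-clock timing is valid) and uses the standard 2·M·N·K FLOP count for GEMM. The helper name `bench_matmul` is hypothetical.

```python
import time
import numpy as np
import pygpukit as gpk

def bench_matmul(n=8192, dtype=np.float32, iters=3):
    """Rough TFLOPS estimate for an n x n matmul (illustrative sketch, not a repo API)."""
    a = gpk.from_numpy(np.random.randn(n, n).astype(dtype))
    b = gpk.from_numpy(np.random.randn(n, n).astype(dtype))
    _ = a @ b  # warmup: kernel selection and first-launch overhead
    best = float("inf")
    for _ in range(iters):
        start = time.perf_counter()
        _ = a @ b  # matmul synchronizes before returning in this release
        best = min(best, time.perf_counter() - start)
    tflops = 2 * n**3 / best / 1e12  # 2*M*N*K FLOPs for an n x n x n GEMM
    print(f"{np.dtype(dtype).name} {n}x{n}: {tflops:.1f} TFLOPS")

bench_matmul(dtype=np.float32)  # CUTLASS TF32 path by default
bench_matmul(dtype=np.float16)  # CUTLASS FP16 TensorCore path
```

> Absolute numbers vary with GPU model and clocks. Setting `PYGPUKIT_NO_CUTLASS=1` (or `PYGPUKIT_NO_TF32=1` for FP32 inputs) switches back to the native kernels, which makes it easy to compare backends with the same script.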
### PyGPUkit Performance by Matrix Size -| Matrix Size | FP32 | TF32 | FP16 | BF16 | -|-------------|------|------|------|------| -| 2048×2048 | 9.6 TFLOPS | 13.2 TFLOPS | 2.4 TFLOPS | 2.4 TFLOPS | -| 4096×4096 | 14.7 TFLOPS | 22.8 TFLOPS | 2.4 TFLOPS | 2.3 TFLOPS | -| 8192×8192 | 16.7 TFLOPS | 29.7 TFLOPS | 2.3 TFLOPS | 2.3 TFLOPS | +| Matrix Size | FP32 (NO_TF32) | TF32 (CUTLASS) | FP16 (CUTLASS) | BF16 (CUTLASS) | +|-------------|----------------|----------------|----------------|----------------| +| 2048×2048 | 9.6 TFLOPS | 13 TFLOPS | 15 TFLOPS | 21 TFLOPS | +| 4096×4096 | 14.7 TFLOPS | 22 TFLOPS | 44 TFLOPS | 44 TFLOPS | +| 8192×8192 | 18 TFLOPS | **31 TFLOPS** | **63 TFLOPS** | **63 TFLOPS** | -> **Note:** FP16/BF16 matmul uses simple kernels with FP32 accumulation. TensorCore optimization planned for future releases (see [Issue #60](https://github.com/m96-chan/PyGPUkit/issues/60)). +> **Note:** CUTLASS is automatic for compatible sizes (16-aligned). Use `PYGPUKIT_NO_TF32=1` for full FP32 precision. --- @@ -227,6 +310,8 @@ manager.create_partition("inference", "Inference", | **QoS Policy** | Guaranteed/Burstable/BestEffort tiers | | **Kernel Pacing** | Bandwidth-based throttling per stream | | **GPU Partitioning** | Resource isolation, multi-tenant support | +| **Multi-LLM Execution** | Concurrent AI model execution with stream isolation | +| **asyncio Integration** | Native Python async/await for concurrent inference | --- @@ -265,12 +350,12 @@ PyGPUkit/ | **v0.2.3** | TF32 TensorCore (PTX mma.sync), 28 TFLOPS | | **v0.2.4** | **Single-binary distribution**, dynamic NVRTC, driver-only mode | | **v0.2.5** | **FP16/BF16 support**, reduction ops, operator overloads, TF32 v2 (~30 TFLOPS) | +| **v0.2.6** | **CUTLASS backend** (31 TFLOPS TF32, 63 TFLOPS FP16/BF16), Multi-LLM concurrent execution | ### Planned | Version | Goals | |---------|-------| -| **v0.2.6** | FP16/BF16 TensorCore optimization, Multi-GPU detection | | **v0.2.7** | Full API review, documentation, backward compatibility | | **v0.3** | Triton backend, advanced ops (softmax, layernorm), MPS/MIG | diff --git a/examples/demo_v026_multi_llm.py b/examples/demo_v026_multi_llm.py new file mode 100644 index 0000000..7c72604 --- /dev/null +++ b/examples/demo_v026_multi_llm.py @@ -0,0 +1,405 @@ +#!/usr/bin/env python3 +""" +PyGPUkit v0.2.6 Multi-LLM Async Execution Demo + +Demonstrates running multiple LLM-like workloads concurrently on a single GPU +using PyGPUkit's native LLM module (GPT2Model with MLP blocks). + +Each workload runs on a separate CUDA stream with independent VRAM budgets. +Uses Python asyncio for non-blocking parallel execution. 
+ +Key differences from PyTorch-based demo: +- Uses PyGPUkit's native matmul (CUTLASS TF32) +- Uses PyGPUkit's native layernorm, gelu +- Real transformer block structure (LayerNorm -> MLP -> Residual) +""" + +from __future__ import annotations + +import asyncio +import time +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + pass + +# Check if multi-LLM scheduler is available +try: + from pygpukit.scheduler import ( + GB, + HAS_MULTI_LLM, + MB, + context_session, + create_context, + destroy_context, + initialize, + reset, + stats, + ) +except ImportError: + HAS_MULTI_LLM = False + +# Check if GPU operations are available +try: + import pygpukit as gpk + from pygpukit.llm import MLP, LayerNorm, TransformerBlock + + HAS_GPU = True +except ImportError: + HAS_GPU = False + + +def section(title: str) -> None: + """Print section header.""" + print() + print("=" * 70) + print(f" {title}") + print("=" * 70) + + +# ============================================================================= +# Real LLM Workloads using PyGPUkit's GPT2Model +# ============================================================================= + + +class PyGPUkitLLM: + """LLM using PyGPUkit's native GPT2Model structure.""" + + def __init__( + self, + name: str, + n_embd: int = 768, + n_layer: int = 6, + n_inner: int | None = None, + ): + self.name = name + self.n_embd = n_embd + self.n_layer = n_layer + self.n_inner = n_inner or (4 * n_embd) + self.blocks: list[TransformerBlock] = [] + self.ln_f: LayerNorm | None = None + + def load_weights(self) -> None: + """Initialize random weights (simulating model loading).""" + if not HAS_GPU: + return + + print( + f" [{self.name}] Loading GPT2-style model (embd={self.n_embd}, layers={self.n_layer})" + ) + + # Create transformer blocks with random weights + self.blocks = [] + for i in range(self.n_layer): + # LayerNorm weights + ln_weight = gpk.from_numpy(np.ones(self.n_embd, dtype=np.float32)) + ln_bias = gpk.from_numpy(np.zeros(self.n_embd, dtype=np.float32)) + + # MLP weights: fc1 [n_inner, n_embd], fc2 [n_embd, n_inner] + c_fc_weight = gpk.from_numpy( + (np.random.randn(self.n_inner, self.n_embd) * 0.02).astype(np.float32) + ) + c_fc_bias = gpk.from_numpy(np.zeros(self.n_inner, dtype=np.float32)) + c_proj_weight = gpk.from_numpy( + (np.random.randn(self.n_embd, self.n_inner) * 0.02).astype(np.float32) + ) + c_proj_bias = gpk.from_numpy(np.zeros(self.n_embd, dtype=np.float32)) + + mlp = MLP(c_fc_weight, c_fc_bias, c_proj_weight, c_proj_bias) + block = TransformerBlock(ln_weight, ln_bias, mlp) + self.blocks.append(block) + + # Final LayerNorm + self.ln_f = LayerNorm( + gpk.from_numpy(np.ones(self.n_embd, dtype=np.float32)), + gpk.from_numpy(np.zeros(self.n_embd, dtype=np.float32)), + ) + + # Calculate model size + params = self.n_layer * ( + self.n_embd # ln weight + + self.n_embd # ln bias + + self.n_inner * self.n_embd # c_fc weight + + self.n_inner # c_fc bias + + self.n_embd * self.n_inner # c_proj weight + + self.n_embd # c_proj bias + ) + print(f" [{self.name}] Parameters: {params / 1e6:.1f}M") + + def forward(self, batch_size: int = 128, seq_len: int = 512) -> np.ndarray: + """Run forward pass through transformer blocks. + + This simulates the MLP portion of transformer inference. 
+ Each block: LayerNorm -> MLP (fc1 -> gelu -> fc2) -> Residual + """ + if not HAS_GPU or not self.blocks: + time.sleep(0.1) + return np.zeros((batch_size, self.n_embd), dtype=np.float32) + + # Create input hidden states [batch_size, n_embd] + hidden = gpk.from_numpy(np.random.randn(batch_size, self.n_embd).astype(np.float32) * 0.1) + + # Apply transformer blocks + for block in self.blocks: + hidden = block(hidden) + + # Final LayerNorm + if self.ln_f: + hidden = self.ln_f(hidden) + + return hidden.to_numpy() + + +# ============================================================================= +# Demo Functions +# ============================================================================= + + +def demo_sequential() -> float: + """Run workloads sequentially (baseline).""" + section("Sequential Execution (Baseline)") + + # Create models with different sizes (simulating different LLMs) + llm_large = PyGPUkitLLM("llm-large", n_embd=1024, n_layer=12) # ~50M MLP params + llm_medium = PyGPUkitLLM("llm-medium", n_embd=768, n_layer=6) # ~14M MLP params + llm_small = PyGPUkitLLM("llm-small", n_embd=512, n_layer=4) # ~4M MLP params + + print("\nLoading models...") + llm_large.load_weights() + llm_medium.load_weights() + llm_small.load_weights() + + # Warmup + print("\nWarmup...") + llm_large.forward(batch_size=64) + llm_medium.forward(batch_size=64) + llm_small.forward(batch_size=64) + + print("\nRunning sequentially (3 iterations)...") + times = [] + for i in range(3): + start = time.perf_counter() + + # Run one after another (simulating sequential inference requests) + result_large = llm_large.forward(batch_size=128) + result_medium = llm_medium.forward(batch_size=128) + result_small = llm_small.forward(batch_size=128) + + elapsed = time.perf_counter() - start + times.append(elapsed) + print(f" Iteration {i + 1}: {elapsed * 1000:.2f} ms") + + avg_elapsed = sum(times) / len(times) + + print("\nResults:") + print(f" Large output shape: {result_large.shape}") + print(f" Medium output shape: {result_medium.shape}") + print(f" Small output shape: {result_small.shape}") + print(f"\n Average time: {avg_elapsed * 1000:.2f} ms") + + return avg_elapsed + + +async def demo_parallel_async() -> float: + """Run workloads in parallel using asyncio.""" + section("Parallel Async Execution (v0.2.6)") + + if not HAS_MULTI_LLM: + print("\n [SKIP] Multi-LLM scheduler not available") + print(" Rebuild PyGPUkit with Rust backend to enable") + return 0.0 + + # Initialize scheduler + initialize(device_id=0) + + # Create execution contexts with VRAM budgets + print("\nCreating execution contexts...") + ctx_large = create_context("llm-large", max_vram=4 * GB) + ctx_medium = create_context("llm-medium", max_vram=2 * GB) + ctx_small = create_context("llm-small", max_vram=1 * GB) + + print( + f" Large context: stream_id={ctx_large.stream_id}, max_vram={ctx_large.max_vram / GB:.1f} GB" + ) + print( + f" Medium context: stream_id={ctx_medium.stream_id}, max_vram={ctx_medium.max_vram / GB:.1f} GB" + ) + print( + f" Small context: stream_id={ctx_small.stream_id}, max_vram={ctx_small.max_vram / GB:.1f} GB" + ) + + # Create models + llm_large = PyGPUkitLLM("llm-large", n_embd=1024, n_layer=12) + llm_medium = PyGPUkitLLM("llm-medium", n_embd=768, n_layer=6) + llm_small = PyGPUkitLLM("llm-small", n_embd=512, n_layer=4) + + print("\nLoading models...") + llm_large.load_weights() + llm_medium.load_weights() + llm_small.load_weights() + + # Define async workloads + async def run_large() -> np.ndarray: + loop = asyncio.get_running_loop() 
+ return await loop.run_in_executor(None, lambda: llm_large.forward(batch_size=128)) + + async def run_medium() -> np.ndarray: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, lambda: llm_medium.forward(batch_size=128)) + + async def run_small() -> np.ndarray: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, lambda: llm_small.forward(batch_size=128)) + + # Warmup + print("\nWarmup...") + async with context_session(ctx_large), context_session(ctx_medium), context_session(ctx_small): + await asyncio.gather(run_large(), run_medium(), run_small()) + + print("\nRunning in parallel (3 iterations)...") + times = [] + for i in range(3): + start = time.perf_counter() + + async with ( + context_session(ctx_large), + context_session(ctx_medium), + context_session(ctx_small), + ): + result_large, result_medium, result_small = await asyncio.gather( + run_large(), + run_medium(), + run_small(), + ) + + elapsed = time.perf_counter() - start + times.append(elapsed) + print(f" Iteration {i + 1}: {elapsed * 1000:.2f} ms") + + avg_elapsed = sum(times) / len(times) + + print("\nResults:") + print(f" Large output shape: {result_large.shape}") + print(f" Medium output shape: {result_medium.shape}") + print(f" Small output shape: {result_small.shape}") + print(f"\n Average time: {avg_elapsed * 1000:.2f} ms") + + # Show scheduler stats + s = stats() + print("\nScheduler stats:") + print(f" Contexts: {s.context_count}") + print(f" VRAM used: {s.used_vram / MB:.1f} MB") + + # Cleanup + destroy_context("llm-large") + destroy_context("llm-medium") + destroy_context("llm-small") + + return avg_elapsed + + +def demo_context_session_api(): + """Demonstrate the context_session API.""" + section("Context Session API Demo") + + if not HAS_MULTI_LLM: + print("\n [SKIP] Multi-LLM scheduler not available") + return + + reset() # Clean slate + initialize(device_id=0) + + print("\nTarget API pattern:") + print(""" + async with context_session(llm_ctx), context_session(tts_ctx): + llm_f = llm_ctx.dispatch_async(llm_req) + tts_f = tts_ctx.dispatch_async(tts_req) + text, audio = await asyncio.gather(llm_f, tts_f) + """) + + # Create contexts + ctx1 = create_context("model_a", max_vram=2 * GB) + ctx2 = create_context("model_b", max_vram=2 * GB) + + print("Sync usage (with statement):") + print(" with context_session(ctx1), context_session(ctx2):") + + with context_session(ctx1), context_session(ctx2): + print(f" ctx1.is_session_active() = {ctx1.is_session_active()}") + print(f" ctx2.is_session_active() = {ctx2.is_session_active()}") + + print(" After exiting:") + print(f" ctx1.is_session_active() = {ctx1.is_session_active()}") + print(f" ctx2.is_session_active() = {ctx2.is_session_active()}") + + # Cleanup + reset() + + +def demo_speedup_comparison(): + """Compare sequential vs parallel execution times.""" + section("Speedup Comparison") + + if not HAS_GPU: + print("\n [SKIP] GPU not available, speedup demo requires GPU") + return + + # Run sequential + seq_time = demo_sequential() + + # Run parallel + par_time = asyncio.run(demo_parallel_async()) + + if par_time > 0: + section("Summary") + print(f"\n Sequential: {seq_time * 1000:.2f} ms") + print(f" Parallel: {par_time * 1000:.2f} ms") + speedup = seq_time / par_time if par_time > 0 else 0 + print(f" Speedup: {speedup:.2f}x") + + if speedup < 1.0: + print("\n Note: Single-GPU parallel execution has overhead.") + print(" Speedup improves with:") + print(" - Multi-GPU setups (true parallelism)") + print(" - I/O-bound 
workloads (async overlapping)") + print(" - CPU preprocessing overlapping GPU compute") + + +def main(): + print("=" * 70) + print(" PyGPUkit v0.2.6 - Multi-LLM Async Execution Demo") + print(" Using native PyGPUkit LLM module (CUTLASS TF32 matmul)") + print("=" * 70) + + print("\nBackend status:") + print(f" GPU available: {HAS_GPU}") + print(f" Multi-LLM scheduler: {HAS_MULTI_LLM}") + + if HAS_GPU: + import pygpukit as gpk + + print(f" CUDA available: {gpk.is_cuda_available()}") + + if not HAS_GPU: + print("\n [WARNING] No GPU available, running in CPU simulation mode") + + # Demo the API + demo_context_session_api() + + # Run comparison + demo_speedup_comparison() + + section("Demo Complete") + print("\nPyGPUkit Multi-LLM features:") + print(" - Native GPT2-style transformer blocks") + print(" - CUTLASS TF32 TensorCore matmul (31+ TFLOPS)") + print(" - Native layernorm, gelu operations") + print(" - Separate CUDA streams per context") + print(" - Independent VRAM budgets") + print(" - asyncio-compatible execution") + + +if __name__ == "__main__": + main() diff --git a/native/CMakeLists.txt b/native/CMakeLists.txt index 500a941..b9903d6 100644 --- a/native/CMakeLists.txt +++ b/native/CMakeLists.txt @@ -21,6 +21,18 @@ find_package(pybind11 CONFIG REQUIRED) include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CUDAToolkit_INCLUDE_DIRS}) +# CUTLASS (header-only library) +set(CUTLASS_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../third_party/cutlass") +if(EXISTS "${CUTLASS_DIR}/include") + message(STATUS "CUTLASS found at: ${CUTLASS_DIR}") + include_directories(${CUTLASS_DIR}/include) + include_directories(${CUTLASS_DIR}/tools/util/include) + add_definitions(-DPYGPUKIT_HAS_CUTLASS=1) +else() + message(STATUS "CUTLASS not found, using fallback kernels") + add_definitions(-DPYGPUKIT_HAS_CUTLASS=0) +endif() + # Set default CUDA architectures if not specified # PyGPUkit requires SM >= 80 (Ampere and newer) # Older architectures (Pascal/Turing) are NOT supported @@ -32,8 +44,8 @@ message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") # Ampere-optimized compiler flags # Add -v for verbose ptxas output to check register usage -# Limit registers to 128 to prevent spilling issues with WMMA kernels -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr --use_fast_math --ptxas-options=-v -maxrregcount=128") +# NOTE: Do NOT use -maxrregcount for CUTLASS - it needs many registers for optimal performance +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr --use_fast_math --ptxas-options=-v -O3") # Build single pybind11 module with all sources pybind11_add_module(_pygpukit_native @@ -48,8 +60,13 @@ pybind11_add_module(_pygpukit_native jit/compiler.cpp jit/kernel.cpp jit/nvrtc_loader.cpp - # Ops - ops/basic.cu + # Ops - Modular structure + ops/elementwise/elementwise.cu + ops/unary/unary.cu + ops/reduction/reduction.cu + ops/matmul/matmul.cu + ops/matmul/matmul_cutlass.cu + ops/nn/nn.cu # Bindings bindings/module.cpp bindings/core_bindings.cpp @@ -64,8 +81,11 @@ target_link_libraries(_pygpukit_native PRIVATE CUDA::cuda_driver ) +# IMPORTANT: Do NOT enable CUDA_SEPARABLE_COMPILATION +# It causes 15x performance degradation for CUTLASS kernels +# due to prevented inlining and indirect function calls set_target_properties(_pygpukit_native PROPERTIES - CUDA_SEPARABLE_COMPILATION ON + CUDA_SEPARABLE_COMPILATION OFF ) # Install the module to the correct location for scikit-build-core diff --git a/native/bindings/ops_bindings.cpp 
b/native/bindings/ops_bindings.cpp index 66a378f..48c6a0a 100644 --- a/native/bindings/ops_bindings.cpp +++ b/native/bindings/ops_bindings.cpp @@ -1,7 +1,7 @@ #include #include -#include "../ops/basic.cuh" +#include "../ops/ops.cuh" namespace py = pybind11; using namespace pygpukit; @@ -114,4 +114,23 @@ void init_ops_bindings(py::module_& m) { m.def("max", &ops::max, py::arg("a"), "Max of all elements (float32/float64 only), returns scalar GPUArray"); + + // ======================================================================== + // Neural Network operations + // ======================================================================== + + // GELU activation + m.def("gelu", &ops::gelu, + py::arg("input"), + "GELU activation: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))"); + + // Bias add (in-place) + m.def("bias_add_inplace", &ops::bias_add_inplace, + py::arg("output"), py::arg("bias"), + "Add bias to output in-place: output[batch, features] += bias[features]"); + + // LayerNorm + m.def("layernorm", &ops::layernorm, + py::arg("input"), py::arg("gamma"), py::arg("beta"), py::arg("eps") = 1e-5f, + "Layer normalization: (x - mean) / sqrt(var + eps) * gamma + beta"); } diff --git a/native/ops/basic.cu b/native/ops/basic.cu index a0a066e..1b00c7c 100644 --- a/native/ops/basic.cu +++ b/native/ops/basic.cu @@ -6,6 +6,8 @@ #include "matmul_f32_tf32.cuh" #include "matmul_f32_tf32_v2.cuh" #include "matmul_f16_bf16.cuh" +#include "matmul_f16_bf16_tc.cuh" +#include "matmul_f16_bf16_tc_generic.cuh" #include "../core/driver_context.hpp" #include #include @@ -1878,29 +1880,37 @@ void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c) { throw std::runtime_error("matmul output shape mismatch"); } - // Check for TF32 TensorCore mode (requires SM >= 80) + // Check for TensorCore modes (requires SM >= 80) // Note: Check on every call since env var might change bool tf32_enabled = false; + bool fp16_tc_enabled = false; int sm_version = 0; - // Check environment variable + // Check environment variables const char* tf32_env = std::getenv("PYGPUKIT_ALLOW_TF32"); + const char* fp16_tc_env = std::getenv("PYGPUKIT_ALLOW_FP16_TC"); // Debug output (remove in production) static bool debug_printed = false; if (!debug_printed) { debug_printed = true; printf("[PyGPUkit] PYGPUKIT_ALLOW_TF32 = %s\n", tf32_env ? tf32_env : "(null)"); + printf("[PyGPUkit] PYGPUKIT_ALLOW_FP16_TC = %s\n", fp16_tc_env ? 
fp16_tc_env : "(null)"); fflush(stdout); } - if (tf32_env && (tf32_env[0] == '1' || tf32_env[0] == 'y' || tf32_env[0] == 'Y')) { - // Check GPU compute capability (using internal helper for driver-only compatibility) + // Check SM version once if any TensorCore mode is requested + if ((tf32_env && (tf32_env[0] == '1' || tf32_env[0] == 'y' || tf32_env[0] == 'Y')) || + (fp16_tc_env && (fp16_tc_env[0] == '1' || fp16_tc_env[0] == 'y' || fp16_tc_env[0] == 'Y'))) { sm_version = get_sm_version_internal(); + } + + if (tf32_env && (tf32_env[0] == '1' || tf32_env[0] == 'y' || tf32_env[0] == 'Y')) { tf32_enabled = (sm_version >= 80); // Ampere or newer - if (!debug_printed) { - fprintf(stderr, "[PyGPUkit] SM version = %d, TF32 enabled = %d\n", sm_version, tf32_enabled); - } + } + + if (fp16_tc_env && (fp16_tc_env[0] == '1' || fp16_tc_env[0] == 'y' || fp16_tc_env[0] == 'Y')) { + fp16_tc_enabled = (sm_version >= 80); // Ampere or newer } // Select kernel based on matrix size and dtype @@ -1912,13 +1922,27 @@ void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c) { K >= OPTIMIZED_MATMUL_THRESHOLD) || (M == 16 && (N == 8 || N == 16))); - bool use_optimized = !use_tf32 && + // FP16/BF16 TensorCore FAST: requires sizes to be exact multiples of tile size + // BM=128, BN=128, BK=32 in fp16_bf16_tc namespace + bool use_fp16_tc_fast = fp16_tc_enabled && + (a.dtype() == DataType::Float16 || a.dtype() == DataType::BFloat16) && + (M >= 128 && N >= 128 && K >= 32) && + (M % 128 == 0 && N % 128 == 0 && K % 32 == 0); + + // FP16/BF16 TensorCore GENERIC: supports M,N >= 16, K % 8 == 0 + // Slower than FAST but more flexible + bool use_fp16_tc_generic = !use_fp16_tc_fast && fp16_tc_enabled && + (a.dtype() == DataType::Float16 || a.dtype() == DataType::BFloat16) && + (M >= 16 && N >= 16 && K >= 8) && + (K % 8 == 0); + + bool use_optimized = !use_tf32 && !use_fp16_tc_fast && !use_fp16_tc_generic && (a.dtype() == DataType::Float32) && (M >= OPTIMIZED_MATMUL_THRESHOLD || N >= OPTIMIZED_MATMUL_THRESHOLD || K >= OPTIMIZED_MATMUL_THRESHOLD); - bool use_tiled = !use_optimized && !use_tf32 && + bool use_tiled = !use_optimized && !use_tf32 && !use_fp16_tc_fast && !use_fp16_tc_generic && (M >= TILED_MATMUL_THRESHOLD || N >= TILED_MATMUL_THRESHOLD || K >= TILED_MATMUL_THRESHOLD); @@ -1940,6 +1964,36 @@ void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c) { static_cast(c.data()), M, N, K); } + } else if (use_fp16_tc_fast) { + // FP16/BF16 TensorCore FAST kernels with mma.sync.m16n8k16 + if (a.dtype() == DataType::Float16) { + fp16_bf16_tc::launch_sgemm_f16_tc( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + M, N, K); + } else { + fp16_bf16_tc::launch_sgemm_bf16_tc( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + M, N, K); + } + } else if (use_fp16_tc_generic) { + // FP16/BF16 TensorCore GENERIC kernels with mma.sync.m16n8k8 (boundary handling) + if (a.dtype() == DataType::Float16) { + fp16_bf16_tc_generic::launch_sgemm_f16_tc_generic( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + M, N, K); + } else { + fp16_bf16_tc_generic::launch_sgemm_bf16_tc_generic( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + M, N, K); + } } else if (use_optimized) { // Ampere-optimized FP32 FMA kernel with cp.async and 4-stage pipeline ampere::launch_sgemm_ampere( diff --git a/native/ops/common/device.cuh b/native/ops/common/device.cuh new file mode 100644 index 
0000000..9d3a31e --- /dev/null +++ b/native/ops/common/device.cuh @@ -0,0 +1,23 @@ +/** + * Device capability helpers + */ +#pragma once + +#include +#include "../../core/driver_context.hpp" + +namespace pygpukit { +namespace ops { + +// Get SM version (e.g., 80 for SM 8.0) +inline int get_sm_version() { + auto& ctx = driver::DriverContext::instance(); + CUdevice device = ctx.get_device(ctx.current_device()); + int major = 0, minor = 0; + cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device); + cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device); + return major * 10 + minor; +} + +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/common/error.cuh b/native/ops/common/error.cuh new file mode 100644 index 0000000..ca7c0ba --- /dev/null +++ b/native/ops/common/error.cuh @@ -0,0 +1,53 @@ +/** + * Error handling and validation helpers + */ +#pragma once + +#include +#include +#include +#include "../../core/memory.hpp" + +namespace pygpukit { +namespace ops { + +// CUDA Driver API error check +inline void check_driver_error(CUresult result, const char* msg) { + if (result != CUDA_SUCCESS) { + const char* error_str = nullptr; + cuGetErrorString(result, &error_str); + throw CudaError(std::string(msg) + ": " + (error_str ? error_str : "unknown error")); + } +} + +// Synchronize and check for errors +inline void sync_and_check(const char* msg) { + check_driver_error(cuCtxSynchronize(), msg); +} + +// Shape validation +inline void validate_same_shape(const GPUArray& a, const GPUArray& b, const char* op_name) { + if (a.shape() != b.shape()) { + throw std::runtime_error(std::string(op_name) + " requires arrays of same shape"); + } +} + +// Dtype validation +inline void validate_same_dtype(const GPUArray& a, const GPUArray& b, const char* op_name) { + if (a.dtype() != b.dtype()) { + throw std::runtime_error(std::string(op_name) + " requires arrays of same dtype"); + } +} + +// Matmul shape validation +inline void validate_matmul_shapes(const GPUArray& a, const GPUArray& b, const char* op_name) { + if (a.ndim() != 2 || b.ndim() != 2) { + throw std::runtime_error(std::string(op_name) + " requires 2D arrays"); + } + if (a.shape()[1] != b.shape()[0]) { + throw std::runtime_error(std::string(op_name) + " dimension mismatch"); + } +} + +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/common/types.cuh b/native/ops/common/types.cuh new file mode 100644 index 0000000..582cc8e --- /dev/null +++ b/native/ops/common/types.cuh @@ -0,0 +1,34 @@ +/** + * Common type definitions and conversion helpers + */ +#pragma once + +#include +#include +#include + +namespace pygpukit { +namespace ops { + +// BF16 conversion helpers (avoid constexpr __host__ issues) +__device__ __forceinline__ float bf16_to_float(__nv_bfloat16 val) { + unsigned short raw; + memcpy(&raw, &val, sizeof(raw)); + unsigned int bits = ((unsigned int)raw) << 16; + float result; + memcpy(&result, &bits, sizeof(result)); + return result; +} + +__device__ __forceinline__ __nv_bfloat16 float_to_bf16(float val) { + unsigned int bits; + memcpy(&bits, &val, sizeof(bits)); + bits += 0x7FFF + ((bits >> 16) & 1); // Round to nearest even + unsigned short raw = (unsigned short)(bits >> 16); + __nv_bfloat16 result; + memcpy(&result, &raw, sizeof(result)); + return result; +} + +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/elementwise/elementwise.cu b/native/ops/elementwise/elementwise.cu new file mode 100644 index 0000000..a9c6df7 --- /dev/null +++ 
b/native/ops/elementwise/elementwise.cu @@ -0,0 +1,266 @@ +/** + * Elementwise binary operations dispatch + */ +#include "elementwise_kernels.cuh" +#include "../common/error.cuh" +#include "../../core/memory.hpp" + +namespace pygpukit { +namespace ops { + +using namespace elementwise; + +// ============================================================================ +// Add +// ============================================================================ + +void add(const GPUArray& a, const GPUArray& b, GPUArray& c) { + validate_same_shape(a, b, "add"); + validate_same_dtype(a, b, "add"); + validate_same_shape(a, c, "add"); + validate_same_dtype(a, c, "add"); + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + add_f32_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Float64: + add_f64_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Int32: + add_i32_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Int64: + add_i64_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Float16: + add_f16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), n); + break; + case DataType::BFloat16: + add_bf16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), n); + break; + } + sync_and_check("add kernel failed"); +} + +GPUArray add(const GPUArray& a, const GPUArray& b) { + validate_same_shape(a, b, "add"); + validate_same_dtype(a, b, "add"); + GPUArray c(a.shape(), a.dtype()); + add(a, b, c); + return c; +} + +// ============================================================================ +// Mul +// ============================================================================ + +void mul(const GPUArray& a, const GPUArray& b, GPUArray& c) { + validate_same_shape(a, b, "mul"); + validate_same_dtype(a, b, "mul"); + validate_same_shape(a, c, "mul"); + validate_same_dtype(a, c, "mul"); + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + mul_f32_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Float64: + mul_f64_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Int32: + mul_i32_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Int64: + mul_i64_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Float16: + mul_f16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), n); + break; + case DataType::BFloat16: + mul_bf16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), n); + break; + } + sync_and_check("mul kernel failed"); +} + +GPUArray mul(const GPUArray& a, const GPUArray& b) { + validate_same_shape(a, b, "mul"); + validate_same_dtype(a, b, "mul"); + GPUArray c(a.shape(), a.dtype()); + mul(a, b, c); + return c; +} + +// 
============================================================================ +// Sub +// ============================================================================ + +void sub(const GPUArray& a, const GPUArray& b, GPUArray& c) { + validate_same_shape(a, b, "sub"); + validate_same_dtype(a, b, "sub"); + validate_same_shape(a, c, "sub"); + validate_same_dtype(a, c, "sub"); + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + sub_f32_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Float64: + sub_f64_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Int32: + sub_i32_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Int64: + sub_i64_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Float16: + sub_f16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), n); + break; + case DataType::BFloat16: + sub_bf16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), n); + break; + } + sync_and_check("sub kernel failed"); +} + +GPUArray sub(const GPUArray& a, const GPUArray& b) { + validate_same_shape(a, b, "sub"); + validate_same_dtype(a, b, "sub"); + GPUArray c(a.shape(), a.dtype()); + sub(a, b, c); + return c; +} + +// ============================================================================ +// Div +// ============================================================================ + +void div(const GPUArray& a, const GPUArray& b, GPUArray& c) { + validate_same_shape(a, b, "div"); + validate_same_dtype(a, b, "div"); + validate_same_shape(a, c, "div"); + validate_same_dtype(a, c, "div"); + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + div_f32_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Float64: + div_f64_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Int32: + div_i32_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Int64: + div_i64_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Float16: + div_f16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), n); + break; + case DataType::BFloat16: + div_bf16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), n); + break; + } + sync_and_check("div kernel failed"); +} + +GPUArray div(const GPUArray& a, const GPUArray& b) { + validate_same_shape(a, b, "div"); + validate_same_dtype(a, b, "div"); + GPUArray c(a.shape(), a.dtype()); + div(a, b, c); + return c; +} + +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/elementwise/elementwise_kernels.cuh b/native/ops/elementwise/elementwise_kernels.cuh new file mode 100644 index 0000000..64dd689 --- /dev/null +++ b/native/ops/elementwise/elementwise_kernels.cuh @@ -0,0 +1,202 @@ +/** + * Elementwise binary operation kernels (add, mul, sub, div) + */ 
+#pragma once + +#include +#include +#include +#include +#include "../common/types.cuh" + +namespace pygpukit { +namespace ops { +namespace elementwise { + +// ============================================================================ +// Add kernels +// ============================================================================ + +__global__ void add_f32_kernel(const float* a, const float* b, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] + b[idx]; + } +} + +__global__ void add_f64_kernel(const double* a, const double* b, double* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] + b[idx]; + } +} + +__global__ void add_i32_kernel(const int32_t* a, const int32_t* b, int32_t* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] + b[idx]; + } +} + +__global__ void add_i64_kernel(const int64_t* a, const int64_t* b, int64_t* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] + b[idx]; + } +} + +__global__ void add_f16_kernel(const __half* a, const __half* b, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __hadd(a[idx], b[idx]); + } +} + +__global__ void add_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = float_to_bf16(bf16_to_float(a[idx]) + bf16_to_float(b[idx])); + } +} + +// ============================================================================ +// Mul kernels +// ============================================================================ + +__global__ void mul_f32_kernel(const float* a, const float* b, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] * b[idx]; + } +} + +__global__ void mul_f64_kernel(const double* a, const double* b, double* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] * b[idx]; + } +} + +__global__ void mul_i32_kernel(const int32_t* a, const int32_t* b, int32_t* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] * b[idx]; + } +} + +__global__ void mul_i64_kernel(const int64_t* a, const int64_t* b, int64_t* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] * b[idx]; + } +} + +__global__ void mul_f16_kernel(const __half* a, const __half* b, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __hmul(a[idx], b[idx]); + } +} + +__global__ void mul_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = float_to_bf16(bf16_to_float(a[idx]) * bf16_to_float(b[idx])); + } +} + +// ============================================================================ +// Sub kernels +// ============================================================================ + +__global__ void sub_f32_kernel(const float* a, const float* b, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] - b[idx]; + } +} + +__global__ void sub_f64_kernel(const double* a, const double* b, double* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + 
if (idx < n) { + c[idx] = a[idx] - b[idx]; + } +} + +__global__ void sub_i32_kernel(const int32_t* a, const int32_t* b, int32_t* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] - b[idx]; + } +} + +__global__ void sub_i64_kernel(const int64_t* a, const int64_t* b, int64_t* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] - b[idx]; + } +} + +__global__ void sub_f16_kernel(const __half* a, const __half* b, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __hsub(a[idx], b[idx]); + } +} + +__global__ void sub_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = float_to_bf16(bf16_to_float(a[idx]) - bf16_to_float(b[idx])); + } +} + +// ============================================================================ +// Div kernels +// ============================================================================ + +__global__ void div_f32_kernel(const float* a, const float* b, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] / b[idx]; + } +} + +__global__ void div_f64_kernel(const double* a, const double* b, double* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] / b[idx]; + } +} + +__global__ void div_i32_kernel(const int32_t* a, const int32_t* b, int32_t* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] / b[idx]; + } +} + +__global__ void div_i64_kernel(const int64_t* a, const int64_t* b, int64_t* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] / b[idx]; + } +} + +__global__ void div_f16_kernel(const __half* a, const __half* b, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __float2half(__half2float(a[idx]) / __half2float(b[idx])); + } +} + +__global__ void div_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = float_to_bf16(bf16_to_float(a[idx]) / bf16_to_float(b[idx])); + } +} + +} // namespace elementwise +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/matmul/matmul.cu b/native/ops/matmul/matmul.cu new file mode 100644 index 0000000..7afa1a1 --- /dev/null +++ b/native/ops/matmul/matmul.cu @@ -0,0 +1,461 @@ +/** + * Matrix multiplication dispatch + */ +#include "matmul_fp32.cuh" +#include "../common/error.cuh" +#include "../common/device.cuh" +#include "../../core/memory.hpp" + +// Include existing optimized kernels +#include "../matmul_f32_ampere.cuh" +#include "../matmul_f32_tf32.cuh" +#include "../matmul_f32_tf32_v2.cuh" +#include "../matmul_f16_bf16.cuh" +#include "../matmul_f16_bf16_tc.cuh" +#include "../matmul_f16_bf16_tc_generic.cuh" + +#include +#include + +// CUTLASS GEMM (extern declarations from matmul_cutlass.cu) +extern "C" { + cudaError_t cutlass_gemm_tf32(const float* A, const float* B, float* C, int M, int N, int K, cudaStream_t stream); + cudaError_t cutlass_gemm_fp16(const __half* A, const __half* B, __half* C, int M, int N, int K, cudaStream_t stream); + cudaError_t cutlass_gemm_bf16(const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C, int M, int N, int K, 
cudaStream_t stream); + bool cutlass_is_compatible(int M, int N, int K); +} + +namespace pygpukit { +namespace ops { + +// Thresholds for kernel selection +constexpr int TILED_MATMUL_THRESHOLD = 128; +constexpr int OPTIMIZED_MATMUL_THRESHOLD = 128; + +void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c) { + validate_matmul_shapes(a, b, "matmul"); + validate_same_dtype(a, b, "matmul"); + + size_t M = a.shape()[0]; + size_t K = a.shape()[1]; + size_t N = b.shape()[1]; + + if (c.shape()[0] != M || c.shape()[1] != N) { + throw std::runtime_error("matmul output shape mismatch"); + } + + // v0.2.6: CUTLASS is the default backend + // Environment variables: + // PYGPUKIT_NO_CUTLASS=1 - Disable CUTLASS entirely, use native kernels + // PYGPUKIT_NO_TF32=1 - Disable TF32 for FP32 inputs (use native FP32 kernel) + const char* no_cutlass_env = std::getenv("PYGPUKIT_NO_CUTLASS"); + const char* no_tf32_env = std::getenv("PYGPUKIT_NO_TF32"); + + bool cutlass_disabled = no_cutlass_env && + (no_cutlass_env[0] == '1' || no_cutlass_env[0] == 'y' || no_cutlass_env[0] == 'Y'); + bool tf32_disabled = no_tf32_env && + (no_tf32_env[0] == '1' || no_tf32_env[0] == 'y' || no_tf32_env[0] == 'Y'); + + // CUTLASS enabled by default if dimensions are compatible + // For FP32: skip CUTLASS TF32 if NO_TF32 is set (will use native FP32 kernel) + bool cutlass_enabled = !cutlass_disabled && cutlass_is_compatible(M, N, K); + bool cutlass_tf32_enabled = cutlass_enabled && !tf32_disabled; + + // Fallback to native TensorCore kernels + bool tf32_enabled = false; + bool fp16_tc_enabled = false; + int sm_version = 0; + + // Only check native TensorCore settings if CUTLASS is disabled + if (!cutlass_enabled) { + const char* tf32_env = std::getenv("PYGPUKIT_ALLOW_TF32"); + const char* fp16_tc_env = std::getenv("PYGPUKIT_ALLOW_FP16_TC"); + + if ((tf32_env && (tf32_env[0] == '1' || tf32_env[0] == 'y' || tf32_env[0] == 'Y')) || + (fp16_tc_env && (fp16_tc_env[0] == '1' || fp16_tc_env[0] == 'y' || fp16_tc_env[0] == 'Y'))) { + sm_version = get_sm_version(); + } + + if (tf32_env && (tf32_env[0] == '1' || tf32_env[0] == 'y' || tf32_env[0] == 'Y')) { + tf32_enabled = (sm_version >= 80); + } + + if (fp16_tc_env && (fp16_tc_env[0] == '1' || fp16_tc_env[0] == 'y' || fp16_tc_env[0] == 'Y')) { + fp16_tc_enabled = (sm_version >= 80); + } + } + + // Kernel selection + bool use_tf32 = tf32_enabled && + (a.dtype() == DataType::Float32) && + ((M >= OPTIMIZED_MATMUL_THRESHOLD && + N >= OPTIMIZED_MATMUL_THRESHOLD && + K >= OPTIMIZED_MATMUL_THRESHOLD) || + (M == 16 && (N == 8 || N == 16))); + + bool use_fp16_tc_fast = fp16_tc_enabled && + (a.dtype() == DataType::Float16 || a.dtype() == DataType::BFloat16) && + (M >= 128 && N >= 128 && K >= 32) && + (M % 128 == 0 && N % 128 == 0 && K % 32 == 0); + + bool use_fp16_tc_generic = !use_fp16_tc_fast && fp16_tc_enabled && + (a.dtype() == DataType::Float16 || a.dtype() == DataType::BFloat16) && + (M >= 16 && N >= 16 && K >= 8) && + (K % 8 == 0); + + bool use_optimized = !use_tf32 && !use_fp16_tc_fast && !use_fp16_tc_generic && + (a.dtype() == DataType::Float32) && + (M >= OPTIMIZED_MATMUL_THRESHOLD || + N >= OPTIMIZED_MATMUL_THRESHOLD || + K >= OPTIMIZED_MATMUL_THRESHOLD); + + bool use_tiled = !use_optimized && !use_tf32 && !use_fp16_tc_fast && !use_fp16_tc_generic && + (M >= TILED_MATMUL_THRESHOLD || + N >= TILED_MATMUL_THRESHOLD || + K >= TILED_MATMUL_THRESHOLD); + + // CUTLASS dispatch (highest priority when enabled) + // FP32 uses TF32 TensorCore (can be disabled with PYGPUKIT_NO_TF32) + // 
FP16/BF16 always use CUTLASS when available + if (cutlass_enabled || cutlass_tf32_enabled) { + cudaError_t err = cudaSuccess; + bool used_cutlass = false; + + switch (a.dtype()) { + case DataType::Float32: + if (cutlass_tf32_enabled) { + err = cutlass_gemm_tf32( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K, nullptr); + used_cutlass = true; + } + break; + case DataType::Float16: + if (cutlass_enabled) { + err = cutlass_gemm_fp16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + M, N, K, nullptr); + used_cutlass = true; + } + break; + case DataType::BFloat16: + if (cutlass_enabled) { + err = cutlass_gemm_bf16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + M, N, K, nullptr); + used_cutlass = true; + } + break; + default: + break; + } + + if (used_cutlass) { + if (err != cudaSuccess) { + throw std::runtime_error("CUTLASS GEMM failed"); + } + sync_and_check("CUTLASS matmul kernel failed"); + return; + } + } + + if (use_tf32) { + if (M == 16 && (N == 8 || N == 16)) { + tf32::launch_single_tile_verified( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + } else { + tf32::launch_sgemm_tf32( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + } + } else if (use_fp16_tc_fast) { + if (a.dtype() == DataType::Float16) { + fp16_bf16_tc::launch_sgemm_f16_tc( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + M, N, K); + } else { + fp16_bf16_tc::launch_sgemm_bf16_tc( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + M, N, K); + } + } else if (use_fp16_tc_generic) { + if (a.dtype() == DataType::Float16) { + fp16_bf16_tc_generic::launch_sgemm_f16_tc_generic( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + M, N, K); + } else { + fp16_bf16_tc_generic::launch_sgemm_bf16_tc_generic( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + M, N, K); + } + } else if (use_optimized) { + ampere::launch_sgemm_ampere( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + } else if (use_tiled) { + switch (a.dtype()) { + case DataType::Float32: + matmul_fp32::launch_tiled_f32( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + break; + case DataType::Float64: + matmul_fp32::launch_tiled_f64( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + break; + case DataType::Float16: + fp16_bf16_matmul::launch_sgemm_f16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + M, N, K); + break; + case DataType::BFloat16: + fp16_bf16_matmul::launch_sgemm_bf16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + M, N, K); + break; + default: + throw std::runtime_error("matmul only supports float types"); + } + } else { + switch (a.dtype()) { + case DataType::Float32: + matmul_fp32::launch_l2opt_f32( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + break; + case DataType::Float64: + matmul_fp32::launch_l2opt_f64( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + break; + case DataType::Float16: + fp16_bf16_matmul::launch_sgemm_f16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + M, N, K); + 
break; + case DataType::BFloat16: + fp16_bf16_matmul::launch_sgemm_bf16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + M, N, K); + break; + default: + throw std::runtime_error("matmul only supports float types"); + } + } + + sync_and_check("matmul kernel failed"); +} + +GPUArray matmul(const GPUArray& a, const GPUArray& b) { + validate_matmul_shapes(a, b, "matmul"); + validate_same_dtype(a, b, "matmul"); + + size_t M = a.shape()[0]; + size_t N = b.shape()[1]; + + GPUArray c({M, N}, a.dtype()); + matmul(a, b, c); + return c; +} + +// Internal helper: matmul with explicit TF32 control +static void matmul_impl(const GPUArray& a, const GPUArray& b, GPUArray& c, bool use_tf32_explicit) { + validate_matmul_shapes(a, b, "matmul"); + validate_same_dtype(a, b, "matmul"); + + size_t M = a.shape()[0]; + size_t K = a.shape()[1]; + size_t N = b.shape()[1]; + + if (c.shape()[0] != M || c.shape()[1] != N) { + throw std::runtime_error("matmul output shape mismatch"); + } + + int sm_version = get_sm_version(); + + bool tf32_enabled = use_tf32_explicit && + (a.dtype() == DataType::Float32) && + (sm_version >= 80); + + if (use_tf32_explicit && !tf32_enabled) { + if (a.dtype() != DataType::Float32) { + throw std::runtime_error("TF32 matmul requires float32 dtype"); + } + if (sm_version < 80) { + throw std::runtime_error("TF32 matmul requires SM >= 80 (Ampere or newer)"); + } + } + + bool use_tf32 = tf32_enabled && + ((M >= OPTIMIZED_MATMUL_THRESHOLD && + N >= OPTIMIZED_MATMUL_THRESHOLD && + K >= OPTIMIZED_MATMUL_THRESHOLD) || + (M == 16 && (N == 8 || N == 16))); + + bool use_optimized = !use_tf32 && + (a.dtype() == DataType::Float32) && + (M >= OPTIMIZED_MATMUL_THRESHOLD || + N >= OPTIMIZED_MATMUL_THRESHOLD || + K >= OPTIMIZED_MATMUL_THRESHOLD); + + bool use_tiled = !use_optimized && !use_tf32 && + (M >= TILED_MATMUL_THRESHOLD || + N >= TILED_MATMUL_THRESHOLD || + K >= TILED_MATMUL_THRESHOLD); + + if (use_tf32) { + if (M == 16 && (N == 8 || N == 16)) { + tf32::launch_single_tile_verified( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + } else { + const char* use_v2 = std::getenv("PYGPUKIT_TF32_V2"); + if (use_v2 && std::string(use_v2) == "1") { + tf32_v2::launch_sgemm_tf32_v2( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + } else { + tf32::launch_sgemm_tf32( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + } + } + } else if (use_optimized) { + ampere::launch_sgemm_ampere( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + } else if (use_tiled) { + switch (a.dtype()) { + case DataType::Float32: + matmul_fp32::launch_tiled_f32( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + break; + case DataType::Float64: + matmul_fp32::launch_tiled_f64( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + break; + case DataType::Float16: + fp16_bf16_matmul::launch_sgemm_f16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + M, N, K); + break; + case DataType::BFloat16: + fp16_bf16_matmul::launch_sgemm_bf16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + M, N, K); + break; + default: + throw std::runtime_error("matmul only supports float32, float64, float16, and bfloat16"); + } + } else { + switch (a.dtype()) { + case DataType::Float32: + 
matmul_fp32::launch_l2opt_f32( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + break; + case DataType::Float64: + matmul_fp32::launch_l2opt_f64( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + break; + case DataType::Float16: + fp16_bf16_matmul::launch_sgemm_f16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + M, N, K); + break; + case DataType::BFloat16: + fp16_bf16_matmul::launch_sgemm_bf16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + M, N, K); + break; + default: + throw std::runtime_error("matmul only supports float32, float64, float16, and bfloat16"); + } + } + + sync_and_check("matmul kernel failed"); +} + +void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c, bool use_tf32) { + matmul_impl(a, b, c, use_tf32); +} + +GPUArray matmul(const GPUArray& a, const GPUArray& b, bool use_tf32) { + validate_matmul_shapes(a, b, "matmul"); + validate_same_dtype(a, b, "matmul"); + + size_t M = a.shape()[0]; + size_t N = b.shape()[1]; + + GPUArray c({M, N}, a.dtype()); + matmul_impl(a, b, c, use_tf32); + return c; +} + +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/matmul/matmul_cutlass.cu b/native/ops/matmul/matmul_cutlass.cu new file mode 100644 index 0000000..0c83de5 --- /dev/null +++ b/native/ops/matmul/matmul_cutlass.cu @@ -0,0 +1,106 @@ +/** + * CUTLASS GEMM instantiation for PyGPUkit + * + * This file instantiates CUTLASS templates for SM 86. + * Separated from main matmul.cu to isolate template compilation. + */ + +#include +#include +#include + +#if PYGPUKIT_HAS_CUTLASS + +#include "../matmul_cutlass.cuh" + +namespace pygpukit { +namespace ops { + +// ============================================================================ +// Explicit C-linkage wrappers for CUTLASS GEMM +// These can be called from the main matmul dispatch +// ============================================================================ + +extern "C" { + +cudaError_t cutlass_gemm_tf32( + const float* A, + const float* B, + float* C, + int M, int N, int K, + cudaStream_t stream +) { + return cutlass_gemm::gemm_tf32(A, B, C, M, N, K, 1.0f, 0.0f, stream); +} + +cudaError_t cutlass_gemm_fp16( + const __half* A, + const __half* B, + __half* C, + int M, int N, int K, + cudaStream_t stream +) { + return cutlass_gemm::gemm_fp16(A, B, C, M, N, K, 1.0f, 0.0f, stream); +} + +cudaError_t cutlass_gemm_bf16( + const __nv_bfloat16* A, + const __nv_bfloat16* B, + __nv_bfloat16* C, + int M, int N, int K, + cudaStream_t stream +) { + return cutlass_gemm::gemm_bf16(A, B, C, M, N, K, 1.0f, 0.0f, stream); +} + +bool cutlass_is_compatible(int M, int N, int K) { + return cutlass_gemm::is_cutlass_compatible(M, N, K); +} + +} // extern "C" + +} // namespace ops +} // namespace pygpukit + +#else // !PYGPUKIT_HAS_CUTLASS + +// Stub implementations when CUTLASS is not available +extern "C" { + +cudaError_t cutlass_gemm_tf32( + const float* A, + const float* B, + float* C, + int M, int N, int K, + cudaStream_t stream +) { + return cudaErrorNotSupported; +} + +cudaError_t cutlass_gemm_fp16( + const __half* A, + const __half* B, + __half* C, + int M, int N, int K, + cudaStream_t stream +) { + return cudaErrorNotSupported; +} + +cudaError_t cutlass_gemm_bf16( + const __nv_bfloat16* A, + const __nv_bfloat16* B, + __nv_bfloat16* C, + int M, int N, int K, + cudaStream_t stream +) { + return cudaErrorNotSupported; +} + +bool cutlass_is_compatible(int M, int 
N, int K) { + return false; +} + +} // extern "C" + +#endif // PYGPUKIT_HAS_CUTLASS diff --git a/native/ops/matmul/matmul_fp32.cuh b/native/ops/matmul/matmul_fp32.cuh new file mode 100644 index 0000000..8ffa21d --- /dev/null +++ b/native/ops/matmul/matmul_fp32.cuh @@ -0,0 +1,383 @@ +/** + * FP32/FP64 Matrix Multiplication Kernels + * + * Three implementations: + * 1. L2-optimized kernel: For small matrices (<128), uses __ldg() cache + * 2. Tiled kernel: For medium matrices, uses shared memory double buffering + * 3. Optimized kernel: For large matrices (>=128), high-performance SGEMM + */ +#pragma once + +#include +#include + +namespace pygpukit { +namespace ops { +namespace matmul_fp32 { + +// Block size for L2-optimized kernel +constexpr int BLOCK_SIZE = 16; + +// Tiled matmul configuration +constexpr int TILE_M = 64; // Output tile height +constexpr int TILE_N = 64; // Output tile width +constexpr int TILE_K = 16; // Reduction tile depth +constexpr int THREAD_M = 4; // Elements per thread in M dimension +constexpr int THREAD_N = 4; // Elements per thread in N dimension + +// ============================================================================ +// L2-Optimized Kernels (Small Matrices) +// ============================================================================ + +__global__ void matmul_f32_l2opt_kernel( + const float* __restrict__ A, + const float* __restrict__ B, + float* __restrict__ C, + size_t M, size_t N, size_t K +) { + const size_t row = blockIdx.y * blockDim.y + threadIdx.y; + const size_t col = blockIdx.x * blockDim.x + threadIdx.x; + + if (row < M && col < N) { + float sum = 0.0f; + #pragma unroll 4 + for (size_t k = 0; k < K; ++k) { + sum += __ldg(&A[row * K + k]) * __ldg(&B[k * N + col]); + } + C[row * N + col] = sum; + } +} + +__global__ void matmul_f64_l2opt_kernel( + const double* __restrict__ A, + const double* __restrict__ B, + double* __restrict__ C, + size_t M, size_t N, size_t K +) { + const size_t row = blockIdx.y * blockDim.y + threadIdx.y; + const size_t col = blockIdx.x * blockDim.x + threadIdx.x; + + if (row < M && col < N) { + double sum = 0.0; + #pragma unroll 4 + for (size_t k = 0; k < K; ++k) { + sum += __ldg(&A[row * K + k]) * __ldg(&B[k * N + col]); + } + C[row * N + col] = sum; + } +} + +// ============================================================================ +// Tiled Kernel with Double Buffering (FP32) +// ============================================================================ + +__global__ void matmul_f32_tiled_kernel( + const float* __restrict__ A, + const float* __restrict__ B, + float* __restrict__ C, + size_t M, size_t N, size_t K +) { + const int tx = threadIdx.x; + const int ty = threadIdx.y; + const int bx = blockIdx.x; + const int by = blockIdx.y; + + __shared__ float As[2][TILE_K][TILE_M + 1]; + __shared__ float Bs[2][TILE_K][TILE_N + 1]; + + float accum[THREAD_M][THREAD_N] = {{0.0f}}; + + const size_t block_row_start = by * TILE_M; + const size_t block_col_start = bx * TILE_N; + + const int tid = ty * blockDim.x + tx; + const int num_threads = blockDim.x * blockDim.y; + const int num_k_tiles = (K + TILE_K - 1) / TILE_K; + + int curr_buf = 0; + + // Prefetch first tile + { + const int a_loads_per_thread = (TILE_M * TILE_K + num_threads - 1) / num_threads; + for (int i = 0; i < a_loads_per_thread; ++i) { + int load_idx = tid + i * num_threads; + if (load_idx < TILE_M * TILE_K) { + int a_row = load_idx / TILE_K; + int a_col = load_idx % TILE_K; + size_t global_row = block_row_start + a_row; + size_t global_col = a_col; + 
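+                    // Cooperative staging: the 16x16 thread block (256 threads) loads the
+                    // TILE_M x TILE_K = 64x16 A tile (1024 floats), i.e. 4 elements per thread.
+                    // The tile is stored transposed as As[k][m], so the compute loop below reads
+                    // consecutive m values with stride 1; the +1 padding on the TILE_M axis
+                    // avoids shared-memory bank conflicts on those reads.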
if (global_row < M && global_col < K) { + As[0][a_col][a_row] = A[global_row * K + global_col]; + } else { + As[0][a_col][a_row] = 0.0f; + } + } + } + + const int b_loads_per_thread = (TILE_K * TILE_N + num_threads - 1) / num_threads; + for (int i = 0; i < b_loads_per_thread; ++i) { + int load_idx = tid + i * num_threads; + if (load_idx < TILE_K * TILE_N) { + int b_row = load_idx / TILE_N; + int b_col = load_idx % TILE_N; + size_t global_row = b_row; + size_t global_col = block_col_start + b_col; + if (global_row < K && global_col < N) { + Bs[0][b_row][b_col] = B[global_row * N + global_col]; + } else { + Bs[0][b_row][b_col] = 0.0f; + } + } + } + } + __syncthreads(); + + for (int tile_k = 0; tile_k < num_k_tiles; ++tile_k) { + int next_buf = 1 - curr_buf; + + if (tile_k + 1 < num_k_tiles) { + size_t k_offset = (tile_k + 1) * TILE_K; + + const int a_loads_per_thread = (TILE_M * TILE_K + num_threads - 1) / num_threads; + for (int i = 0; i < a_loads_per_thread; ++i) { + int load_idx = tid + i * num_threads; + if (load_idx < TILE_M * TILE_K) { + int a_row = load_idx / TILE_K; + int a_col = load_idx % TILE_K; + size_t global_row = block_row_start + a_row; + size_t global_col = k_offset + a_col; + if (global_row < M && global_col < K) { + As[next_buf][a_col][a_row] = A[global_row * K + global_col]; + } else { + As[next_buf][a_col][a_row] = 0.0f; + } + } + } + + const int b_loads_per_thread = (TILE_K * TILE_N + num_threads - 1) / num_threads; + for (int i = 0; i < b_loads_per_thread; ++i) { + int load_idx = tid + i * num_threads; + if (load_idx < TILE_K * TILE_N) { + int b_row = load_idx / TILE_N; + int b_col = load_idx % TILE_N; + size_t global_row = k_offset + b_row; + size_t global_col = block_col_start + b_col; + if (global_row < K && global_col < N) { + Bs[next_buf][b_row][b_col] = B[global_row * N + global_col]; + } else { + Bs[next_buf][b_row][b_col] = 0.0f; + } + } + } + } + + #pragma unroll + for (int k = 0; k < TILE_K; ++k) { + float a_frag[THREAD_M]; + #pragma unroll + for (int m = 0; m < THREAD_M; ++m) { + a_frag[m] = As[curr_buf][k][ty * THREAD_M + m]; + } + + #pragma unroll + for (int n = 0; n < THREAD_N; ++n) { + float b_val = Bs[curr_buf][k][tx * THREAD_N + n]; + #pragma unroll + for (int m = 0; m < THREAD_M; ++m) { + accum[m][n] += a_frag[m] * b_val; + } + } + } + + __syncthreads(); + curr_buf = next_buf; + } + + #pragma unroll + for (int m = 0; m < THREAD_M; ++m) { + size_t out_row = block_row_start + ty * THREAD_M + m; + if (out_row < M) { + #pragma unroll + for (int n = 0; n < THREAD_N; ++n) { + size_t out_col = block_col_start + tx * THREAD_N + n; + if (out_col < N) { + C[out_row * N + out_col] = accum[m][n]; + } + } + } + } +} + +// ============================================================================ +// Tiled Kernel with Double Buffering (FP64) +// ============================================================================ + +__global__ void matmul_f64_tiled_kernel( + const double* __restrict__ A, + const double* __restrict__ B, + double* __restrict__ C, + size_t M, size_t N, size_t K +) { + const int tx = threadIdx.x; + const int ty = threadIdx.y; + const int bx = blockIdx.x; + const int by = blockIdx.y; + + constexpr int TILE_K_F64 = 8; + __shared__ double As[2][TILE_K_F64][TILE_M + 1]; + __shared__ double Bs[2][TILE_K_F64][TILE_N + 1]; + + double accum[THREAD_M][THREAD_N] = {{0.0}}; + + const size_t block_row_start = by * TILE_M; + const size_t block_col_start = bx * TILE_N; + + const int tid = ty * blockDim.x + tx; + const int num_threads = blockDim.x * 
blockDim.y; + const int num_k_tiles = (K + TILE_K_F64 - 1) / TILE_K_F64; + + int curr_buf = 0; + + // Prefetch first tile + { + const int a_loads_per_thread = (TILE_M * TILE_K_F64 + num_threads - 1) / num_threads; + for (int i = 0; i < a_loads_per_thread; ++i) { + int load_idx = tid + i * num_threads; + if (load_idx < TILE_M * TILE_K_F64) { + int a_row = load_idx / TILE_K_F64; + int a_col = load_idx % TILE_K_F64; + size_t global_row = block_row_start + a_row; + size_t global_col = a_col; + if (global_row < M && global_col < K) { + As[0][a_col][a_row] = A[global_row * K + global_col]; + } else { + As[0][a_col][a_row] = 0.0; + } + } + } + + const int b_loads_per_thread = (TILE_K_F64 * TILE_N + num_threads - 1) / num_threads; + for (int i = 0; i < b_loads_per_thread; ++i) { + int load_idx = tid + i * num_threads; + if (load_idx < TILE_K_F64 * TILE_N) { + int b_row = load_idx / TILE_N; + int b_col = load_idx % TILE_N; + size_t global_row = b_row; + size_t global_col = block_col_start + b_col; + if (global_row < K && global_col < N) { + Bs[0][b_row][b_col] = B[global_row * N + global_col]; + } else { + Bs[0][b_row][b_col] = 0.0; + } + } + } + } + __syncthreads(); + + for (int tile_k = 0; tile_k < num_k_tiles; ++tile_k) { + int next_buf = 1 - curr_buf; + + if (tile_k + 1 < num_k_tiles) { + size_t k_offset = (tile_k + 1) * TILE_K_F64; + + const int a_loads_per_thread = (TILE_M * TILE_K_F64 + num_threads - 1) / num_threads; + for (int i = 0; i < a_loads_per_thread; ++i) { + int load_idx = tid + i * num_threads; + if (load_idx < TILE_M * TILE_K_F64) { + int a_row = load_idx / TILE_K_F64; + int a_col = load_idx % TILE_K_F64; + size_t global_row = block_row_start + a_row; + size_t global_col = k_offset + a_col; + if (global_row < M && global_col < K) { + As[next_buf][a_col][a_row] = A[global_row * K + global_col]; + } else { + As[next_buf][a_col][a_row] = 0.0; + } + } + } + + const int b_loads_per_thread = (TILE_K_F64 * TILE_N + num_threads - 1) / num_threads; + for (int i = 0; i < b_loads_per_thread; ++i) { + int load_idx = tid + i * num_threads; + if (load_idx < TILE_K_F64 * TILE_N) { + int b_row = load_idx / TILE_N; + int b_col = load_idx % TILE_N; + size_t global_row = k_offset + b_row; + size_t global_col = block_col_start + b_col; + if (global_row < K && global_col < N) { + Bs[next_buf][b_row][b_col] = B[global_row * N + global_col]; + } else { + Bs[next_buf][b_row][b_col] = 0.0; + } + } + } + } + + #pragma unroll + for (int k = 0; k < TILE_K_F64; ++k) { + double a_frag[THREAD_M]; + #pragma unroll + for (int m = 0; m < THREAD_M; ++m) { + a_frag[m] = As[curr_buf][k][ty * THREAD_M + m]; + } + + #pragma unroll + for (int n = 0; n < THREAD_N; ++n) { + double b_val = Bs[curr_buf][k][tx * THREAD_N + n]; + #pragma unroll + for (int m = 0; m < THREAD_M; ++m) { + accum[m][n] += a_frag[m] * b_val; + } + } + } + + __syncthreads(); + curr_buf = next_buf; + } + + #pragma unroll + for (int m = 0; m < THREAD_M; ++m) { + size_t out_row = block_row_start + ty * THREAD_M + m; + if (out_row < M) { + #pragma unroll + for (int n = 0; n < THREAD_N; ++n) { + size_t out_col = block_col_start + tx * THREAD_N + n; + if (out_col < N) { + C[out_row * N + out_col] = accum[m][n]; + } + } + } + } +} + +// ============================================================================ +// Launch Helpers +// ============================================================================ + +inline void launch_l2opt_f32(const float* A, const float* B, float* C, size_t M, size_t N, size_t K) { + dim3 block_size(BLOCK_SIZE, BLOCK_SIZE); 
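+    // Grid covers the whole output: x tiles the N columns, y tiles the M rows.
+    // Example: M = N = 1000 with BLOCK_SIZE = 16 launches a 63 x 63 grid of
+    // 16 x 16 blocks; the kernel's bounds check masks off the partial edge tiles.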
+    dim3 grid_size((N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
+    matmul_f32_l2opt_kernel<<<grid_size, block_size>>>(A, B, C, M, N, K);
+}
+
+inline void launch_l2opt_f64(const double* A, const double* B, double* C, size_t M, size_t N, size_t K) {
+    dim3 block_size(BLOCK_SIZE, BLOCK_SIZE);
+    dim3 grid_size((N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
+    matmul_f64_l2opt_kernel<<<grid_size, block_size>>>(A, B, C, M, N, K);
+}
+
+inline void launch_tiled_f32(const float* A, const float* B, float* C, size_t M, size_t N, size_t K) {
+    dim3 block_size(TILE_N / THREAD_N, TILE_M / THREAD_M);
+    dim3 grid_size((N + TILE_N - 1) / TILE_N, (M + TILE_M - 1) / TILE_M);
+    matmul_f32_tiled_kernel<<<grid_size, block_size>>>(A, B, C, M, N, K);
+}
+
+inline void launch_tiled_f64(const double* A, const double* B, double* C, size_t M, size_t N, size_t K) {
+    dim3 block_size(TILE_N / THREAD_N, TILE_M / THREAD_M);
+    dim3 grid_size((N + TILE_N - 1) / TILE_N, (M + TILE_M - 1) / TILE_M);
+    matmul_f64_tiled_kernel<<<grid_size, block_size>>>(A, B, C, M, N, K);
+}
+
+} // namespace matmul_fp32
+} // namespace ops
+} // namespace pygpukit
diff --git a/native/ops/matmul_cutlass.cuh b/native/ops/matmul_cutlass.cuh
new file mode 100644
index 0000000..4c103ff
--- /dev/null
+++ b/native/ops/matmul_cutlass.cuh
@@ -0,0 +1,329 @@
+/**
+ * CUTLASS-based GEMM kernels for PyGPUkit
+ *
+ * Provides high-performance matrix multiplication using NVIDIA CUTLASS library.
+ * Targets SM 86 (RTX 30 series) with TensorCore support.
+ *
+ * Supported dtypes:
+ * - FP32 (with TF32 TensorCore acceleration)
+ * - FP16 (native TensorCore)
+ * - BF16 (native TensorCore)
+ */
+#pragma once
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/device/gemm.h"
+#include "cutlass/util/device_memory.h"
+
+namespace pygpukit {
+namespace ops {
+namespace cutlass_gemm {
+
+// ============================================================================
+// TF32 GEMM (FP32 input/output, TF32 TensorCore)
+// ============================================================================
+
+// TF32 GEMM: FP32 in -> TF32 TensorCore -> FP32 out
+// For row-major inputs, use all-ColumnMajor with transpose trick:
+//   C (M×N row) = A (M×K row) @ B (K×N row)
+//   becomes: C^T (N×M col) = B^T (N×K col) @ A^T (K×M col)
+// where row-major X = col-major X^T in memory
+using TF32Gemm = cutlass::gemm::device::Gemm<
+    float,                                   // ElementA (will be B^T)
+    cutlass::layout::ColumnMajor,            // LayoutA
+    float,                                   // ElementB (will be A^T)
+    cutlass::layout::ColumnMajor,            // LayoutB
+    float,                                   // ElementC (will be C^T)
+    cutlass::layout::ColumnMajor,            // LayoutC
+    float,                                   // ElementAccumulator
+    cutlass::arch::OpClassTensorOp,          // OperatorClass (TensorCore)
+    cutlass::arch::Sm80,                     // ArchTag (Ampere)
+    cutlass::gemm::GemmShape<128, 128, 16>,  // ThreadBlockShape
+    cutlass::gemm::GemmShape<64, 64, 16>,    // WarpShape
+    cutlass::gemm::GemmShape<16, 8, 8>,      // InstructionShape (mma.sync)
+    cutlass::epilogue::thread::LinearCombination<
+        float, 128 / cutlass::sizeof_bits<float>::value,
+        float, float>,                       // EpilogueOp (128-bit aligned)
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    4                                        // Stages (pipeline depth)
+>;
+
+// ============================================================================
+// FP16 GEMM (FP16 input/output, FP16 TensorCore)
+// ============================================================================
+
+// FP16 GEMM with same transpose trick as TF32 (all ColumnMajor)
+using FP16Gemm = cutlass::gemm::device::Gemm<
+    cutlass::half_t,                         // ElementA (will be B^T)
cutlass::layout::ColumnMajor, // LayoutA + cutlass::half_t, // ElementB (will be A^T) + cutlass::layout::ColumnMajor, // LayoutB + cutlass::half_t, // ElementC (will be C^T) + cutlass::layout::ColumnMajor, // LayoutC + float, // ElementAccumulator (FP32 for precision) + cutlass::arch::OpClassTensorOp, // OperatorClass (TensorCore) + cutlass::arch::Sm80, // ArchTag (Ampere) + cutlass::gemm::GemmShape<128, 128, 32>, // ThreadBlockShape + cutlass::gemm::GemmShape<64, 64, 32>, // WarpShape + cutlass::gemm::GemmShape<16, 8, 16>, // InstructionShape (mma.sync.m16n8k16) + cutlass::epilogue::thread::LinearCombination< + cutlass::half_t, 128 / cutlass::sizeof_bits::value, + float, float>, // EpilogueOp (128-bit aligned) + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 // Stages +>; + +// ============================================================================ +// BF16 GEMM (BF16 input/output, BF16 TensorCore) +// ============================================================================ + +// BF16 GEMM with same transpose trick as TF32 (all ColumnMajor) +using BF16Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, // ElementA (will be B^T) + cutlass::layout::ColumnMajor, // LayoutA + cutlass::bfloat16_t, // ElementB (will be A^T) + cutlass::layout::ColumnMajor, // LayoutB + cutlass::bfloat16_t, // ElementC (will be C^T) + cutlass::layout::ColumnMajor, // LayoutC + float, // ElementAccumulator (FP32 for precision) + cutlass::arch::OpClassTensorOp, // OperatorClass (TensorCore) + cutlass::arch::Sm80, // ArchTag (Ampere) + cutlass::gemm::GemmShape<128, 128, 32>, // ThreadBlockShape + cutlass::gemm::GemmShape<64, 64, 32>, // WarpShape + cutlass::gemm::GemmShape<16, 8, 16>, // InstructionShape + cutlass::epilogue::thread::LinearCombination< + cutlass::bfloat16_t, 128 / cutlass::sizeof_bits::value, + float, float>, // EpilogueOp (128-bit aligned) + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 // Stages +>; + +// ============================================================================ +// Wrapper functions +// ============================================================================ + +/** + * TF32 GEMM: C = alpha * A @ B + beta * C + * + * @param A Input matrix A (M x K), row-major, FP32 + * @param B Input matrix B (K x N), row-major, FP32 + * @param C Output matrix C (M x N), row-major, FP32 + * @param M Number of rows in A and C + * @param N Number of columns in B and C + * @param K Number of columns in A and rows in B + * @param alpha Scalar multiplier for A @ B + * @param beta Scalar multiplier for C (set to 0 for C = A @ B) + * @param stream CUDA stream + * @return cudaError_t + * + * Layout trick for row-major inputs with RowMajor×ColumnMajor kernel: + * - CUTLASS kernel: D (M×N row) = A (M×K row) @ B (K×N col) + * - Our inputs: C (M×N row) = A (M×K row) @ B (K×N row) + * + * Key insight: row-major B (K×N) = column-major B^T (N×K) in memory + * + * We compute: C^T (N×M row) = B^T (N×K row) @ A^T (K×M col) + * Which is equivalent to: C (M×N row) = A (M×K row) @ B (K×N row) + * + * For the kernel: + * - M' = N, N' = M, K' = K + * - A' = B^T (N×K row-major), pointer = B, ld = N (stride between rows) + * - B' = A^T (K×M col-major) = A (M×K row-major) in memory, pointer = A, ld = K + * - C' = C^T (N×M row-major), pointer = C, ld = M (stride between rows) + */ +inline cudaError_t gemm_tf32( + const float* A, + const float* B, + float* C, + int M, int N, int K, + float alpha = 1.0f, + float beta = 0.0f, + cudaStream_t stream = nullptr +) { + // 
Transpose trick for row-major inputs with all-ColumnMajor kernel: + // C (M×N row) = A (M×K row) @ B (K×N row) + // becomes: C^T (N×M col) = B^T (N×K col) @ A^T (K×M col) + // + // Memory equivalence: row-major X (R×C) = col-major X^T (C×R) + // So we reinterpret pointers without copying: + // - B (K×N row) in memory = B^T (N×K col), which is our "A" operand + // - A (M×K row) in memory = A^T (K×M col), which is our "B" operand + // - C (M×N row) in memory = C^T (N×M col), which is our output + // + // problem_size(M', N', K') for output M'×N' = (N, M, K) + cutlass::gemm::GemmCoord problem_size(N, M, K); + + // For column-major matrices, leading dimension = number of rows + // - B^T is N×K col-major, ld = N (num rows) + // - A^T is K×M col-major, ld = K (num rows) + // - C^T is N×M col-major, ld = N (num rows) + typename TF32Gemm::Arguments arguments{ + problem_size, + {B, N}, // "A" operand: B^T (N×K col-major), ld = N + {A, K}, // "B" operand: A^T (K×M col-major), ld = K + {C, N}, // "C" operand: C^T (N×M col-major), ld = N + {C, N}, // D = C + {alpha, beta} // Epilogue params + }; + + TF32Gemm gemm_op; + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return cudaErrorInvalidValue; + } + + size_t workspace_size = TF32Gemm::get_workspace_size(arguments); + + if (workspace_size == 0) { + status = gemm_op.initialize(arguments, nullptr, stream); + } else { + cutlass::device_memory::allocation workspace(workspace_size); + status = gemm_op.initialize(arguments, workspace.get(), stream); + } + + if (status != cutlass::Status::kSuccess) { + return cudaErrorInvalidValue; + } + + status = gemm_op(stream); + if (status != cutlass::Status::kSuccess) { + return cudaErrorInvalidValue; + } + + return cudaSuccess; +} + +/** + * FP16 GEMM: C = alpha * A @ B + beta * C (row-major inputs) + * Uses same transpose trick as TF32 + */ +inline cudaError_t gemm_fp16( + const __half* A, + const __half* B, + __half* C, + int M, int N, int K, + float alpha = 1.0f, + float beta = 0.0f, + cudaStream_t stream = nullptr +) { + // Same transpose trick as TF32: compute C^T = B^T @ A^T + cutlass::gemm::GemmCoord problem_size(N, M, K); + + // Cast to CUTLASS types + const cutlass::half_t* A_cutlass = reinterpret_cast(A); + const cutlass::half_t* B_cutlass = reinterpret_cast(B); + cutlass::half_t* C_cutlass = reinterpret_cast(C); + + // Leading dimensions for col-major transpose trick (ld = num rows) + typename FP16Gemm::Arguments arguments{ + problem_size, + {B_cutlass, N}, // "A" = B^T (N×K col-major), ld = N + {A_cutlass, K}, // "B" = A^T (K×M col-major), ld = K + {C_cutlass, N}, // "C" = C^T (N×M col-major), ld = N + {C_cutlass, N}, // D = C + {alpha, beta} + }; + + FP16Gemm gemm_op; + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return cudaErrorInvalidValue; + } + + size_t workspace_size = FP16Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get(), stream); + if (status != cutlass::Status::kSuccess) { + return cudaErrorInvalidValue; + } + + status = gemm_op(stream); + if (status != cutlass::Status::kSuccess) { + return cudaErrorInvalidValue; + } + + return cudaSuccess; +} + +/** + * BF16 GEMM: C = alpha * A @ B + beta * C (row-major inputs) + * Uses same transpose trick as TF32 + */ +inline cudaError_t gemm_bf16( + const __nv_bfloat16* A, + const __nv_bfloat16* B, + __nv_bfloat16* C, + int M, int N, 
int K, + float alpha = 1.0f, + float beta = 0.0f, + cudaStream_t stream = nullptr +) { + // Same transpose trick as TF32: compute C^T = B^T @ A^T + cutlass::gemm::GemmCoord problem_size(N, M, K); + + // Cast to CUTLASS types + const cutlass::bfloat16_t* A_cutlass = reinterpret_cast(A); + const cutlass::bfloat16_t* B_cutlass = reinterpret_cast(B); + cutlass::bfloat16_t* C_cutlass = reinterpret_cast(C); + + // Leading dimensions for col-major transpose trick (ld = num rows) + typename BF16Gemm::Arguments arguments{ + problem_size, + {B_cutlass, N}, // "A" = B^T (N×K col-major), ld = N + {A_cutlass, K}, // "B" = A^T (K×M col-major), ld = K + {C_cutlass, N}, // "C" = C^T (N×M col-major), ld = N + {C_cutlass, N}, // D = C + {alpha, beta} + }; + + BF16Gemm gemm_op; + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return cudaErrorInvalidValue; + } + + size_t workspace_size = BF16Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get(), stream); + if (status != cutlass::Status::kSuccess) { + return cudaErrorInvalidValue; + } + + status = gemm_op(stream); + if (status != cutlass::Status::kSuccess) { + return cudaErrorInvalidValue; + } + + return cudaSuccess; +} + +// ============================================================================ +// Dispatch function for runtime dtype selection +// ============================================================================ + +enum class GemmDtype { + FP32_TF32, // FP32 input, TF32 TensorCore + FP16, // FP16 TensorCore + BF16 // BF16 TensorCore +}; + +/** + * Check if matrix dimensions are compatible with CUTLASS TensorCore kernels + * TensorCore requires alignment to tile sizes + */ +inline bool is_cutlass_compatible(int M, int N, int K) { + // Minimum alignment for TensorCore (based on ThreadBlockShape) + // TF32: 128x128x16, FP16/BF16: 128x128x32 + // For simplicity, require 16-alignment on all dimensions + return (M % 16 == 0) && (N % 16 == 0) && (K % 16 == 0); +} + +} // namespace cutlass_gemm +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/matmul_f16_bf16_tc.cuh b/native/ops/matmul_f16_bf16_tc.cuh new file mode 100644 index 0000000..f4141e5 --- /dev/null +++ b/native/ops/matmul_f16_bf16_tc.cuh @@ -0,0 +1,482 @@ +/** + * FP16/BF16 TensorCore Matrix Multiplication + * + * Uses mma.sync.aligned.m16n8k16 for TensorCore acceleration + * - FP16: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + * - BF16: mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 + * + * Both use FP32 accumulation for numerical stability. 
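+ *
+ * Tile configuration (per 256-thread block = 8 warps):
+ *   - Block tile BM x BN = 128 x 128 with BK = 32 (two m16n8k16 steps per BK slice)
+ *   - Each warp covers a 32 x 64 sub-tile (WARP_TILES_M x WARP_TILES_N = 2 x 8 MMA tiles)
+ *   - Global-to-shared copies use cp.async with double buffering
+ *
+ * This fast path is intended for shapes aligned to the block tile; other sizes
+ * should use the generic variant in matmul_f16_bf16_tc_generic.cuh.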
+ * + * Performance target: 50+ TFLOPS on RTX 3090 Ti + */ + +#pragma once +#include +#include +#include + +namespace pygpukit { +namespace ops { +namespace fp16_bf16_tc { + +// Block tile dimensions +constexpr int BM = 128; +constexpr int BN = 128; +constexpr int BK = 32; // K=16 per MMA, 2 MMAs per BK iteration + +// MMA tile dimensions (m16n8k16) +constexpr int MMA_M = 16; +constexpr int MMA_N = 8; +constexpr int MMA_K = 16; + +// Warp configuration +constexpr int WARPS_M = 4; // 4 warps along M +constexpr int WARPS_N = 2; // 2 warps along N +constexpr int WARP_TILES_M = 2; // 2 MMA tiles per warp along M +constexpr int WARP_TILES_N = 8; // 8 MMA tiles per warp along N + +// Padding to avoid bank conflicts +constexpr int A_PAD = 8; +constexpr int B_PAD = 8; + +// ============================================================ +// cp.async helpers +// ============================================================ +__device__ __forceinline__ uint32_t smem_u32(const void* ptr) { + uint32_t addr; + asm volatile( + "{ .reg .u64 smem64; " + " cvta.to.shared.u64 smem64, %1; " + " cvt.u32.u64 %0, smem64; }" + : "=r"(addr) : "l"(ptr) + ); + return addr; +} + +__device__ __forceinline__ void cp_async_16(void* smem, const void* gmem) { + uint32_t addr = smem_u32(smem); + asm volatile( + "cp.async.cg.shared.global [%0], [%1], 16;" + :: "r"(addr), "l"(gmem) + ); +} + +__device__ __forceinline__ void cp_async_commit() { + asm volatile("cp.async.commit_group;"); +} + +__device__ __forceinline__ void cp_async_wait_0() { + asm volatile("cp.async.wait_group 0;"); +} + +// ============================================================ +// FP16 TensorCore GEMM Kernel (FP32 accumulation) +// ============================================================ +__global__ void __launch_bounds__(256, 2) +sgemm_f16_tc_kernel( + const __half* __restrict__ A, + const __half* __restrict__ B, + __half* __restrict__ C, + int M, int N, int K +) { + const int tid = threadIdx.x; + const int warp_id = tid >> 5; + const int lane = tid & 31; + + const int cta_m = blockIdx.y * BM; + const int cta_n = blockIdx.x * BN; + + const int warp_row = warp_id / WARPS_N; + const int warp_col = warp_id % WARPS_N; + + const int warp_m = warp_row * (WARP_TILES_M * MMA_M); + const int warp_n = warp_col * (WARP_TILES_N * MMA_N); + + // Shared memory: store as half + __shared__ __half smA[2][BM][BK + A_PAD]; + __shared__ __half smB[2][BK][BN + B_PAD]; + + // Accumulators (FP32) + float acc[WARP_TILES_M][WARP_TILES_N][4] = {}; + + const int num_k_tiles = K / BK; + + // Fragment index mappings for m16n8k16 + // A fragment: 8 half elements per thread = 4 packed uint32 + // groupID = lane >> 2, threadID_in_group = lane % 4 + const int groupID = lane >> 2; // 0-7 + const int tid_in_group = lane & 3; // 0-3 + + // C fragment mapping (same as TF32 m16n8k8 output) + const int c_row_base = groupID; + const int c_col_base = tid_in_group * 2; + + // ====== cp.async load helpers ====== + auto load_A_async = [&](int stage, int kt) { + // 256 threads, load BM*BK = 128*32 = 4096 halves + // Each thread loads 16 halves = 32 bytes = 2x cp.async_16 + const int elems_per_thread = (BM * BK) / 256; // 16 + const int half_per_load = 8; // cp.async_16 loads 8 halves + + #pragma unroll + for (int i = 0; i < elems_per_thread / half_per_load; ++i) { + int elem_idx = tid * (elems_per_thread / half_per_load) + i; + int row = (elem_idx * half_per_load) / BK; + int col = (elem_idx * half_per_load) % BK; + int gm = cta_m + row; + int gk = kt * BK + col; + if (gm < M && gk + 7 < K) { 
+ cp_async_16(&smA[stage][row][col], &A[gm * K + gk]); + } + } + }; + + auto load_B_async = [&](int stage, int kt) { + // 256 threads, load BK*BN = 32*128 = 4096 halves + const int elems_per_thread = (BK * BN) / 256; // 16 + const int half_per_load = 8; + + #pragma unroll + for (int i = 0; i < elems_per_thread / half_per_load; ++i) { + int elem_idx = tid * (elems_per_thread / half_per_load) + i; + int row = (elem_idx * half_per_load) / BN; + int col = (elem_idx * half_per_load) % BN; + int gk = kt * BK + row; + int gn = cta_n + col; + if (gk < K && gn + 7 < N) { + cp_async_16(&smB[stage][row][col], &B[gk * N + gn]); + } + } + }; + + // ====== Prologue: load first tile ====== + load_A_async(0, 0); + load_B_async(0, 0); + cp_async_commit(); + cp_async_wait_0(); + __syncthreads(); + + // ====== Main loop with double buffering ====== + for (int kt = 0; kt < num_k_tiles; ++kt) { + int curr = kt & 1; + int next = curr ^ 1; + + // Prefetch next tile + if (kt + 1 < num_k_tiles) { + load_A_async(next, kt + 1); + load_B_async(next, kt + 1); + } + cp_async_commit(); + + // Process current tile: 2 MMA iterations per BK (BK=32, MMA_K=16) + #pragma unroll + for (int kk = 0; kk < BK; kk += MMA_K) { + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + int tile_m = warp_m + wm * MMA_M; + + // Load A fragment (8 halves = 4 packed uint32) + // A is 16x16, row-major + // For mma.m16n8k16: + // a[i] where i=0..7 maps to: + // row = groupID + 8 * ((i/2) % 2) + // col = tid_in_group * 2 + (i % 2) + 8 * (i / 4) + uint32_t a_frag[4]; + + // Pack halves into uint32 + // a_frag[0] = (a[1] << 16) | a[0] + // a_frag[1] = (a[3] << 16) | a[2] + // a_frag[2] = (a[5] << 16) | a[4] + // a_frag[3] = (a[7] << 16) | a[6] + + #pragma unroll + for (int p = 0; p < 4; ++p) { + int i0 = p * 2; + int i1 = p * 2 + 1; + + int row0 = groupID + 8 * ((i0 / 2) % 2); + int col0 = tid_in_group * 2 + (i0 % 2) + 8 * (i0 / 4); + int row1 = groupID + 8 * ((i1 / 2) % 2); + int col1 = tid_in_group * 2 + (i1 % 2) + 8 * (i1 / 4); + + __half h0 = smA[curr][tile_m + row0][kk + col0]; + __half h1 = smA[curr][tile_m + row1][kk + col1]; + + // Pack two halves into uint32 + a_frag[p] = __half_as_ushort(h0) | (uint32_t(__half_as_ushort(h1)) << 16); + } + + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + int tile_n = warp_n + wn * MMA_N; + + // Load B fragment (4 halves = 2 packed uint32) + // B is 16x8, col-major storage in row-major layout + // For mma.m16n8k16: + // b[i] where i=0..3 maps to: + // row = tid_in_group * 2 + (i % 2) + 8 * (i / 2) + // col = groupID + uint32_t b_frag[2]; + + #pragma unroll + for (int p = 0; p < 2; ++p) { + int i0 = p * 2; + int i1 = p * 2 + 1; + + int row0 = tid_in_group * 2 + (i0 % 2) + 8 * (i0 / 2); + int col0 = groupID; + int row1 = tid_in_group * 2 + (i1 % 2) + 8 * (i1 / 2); + int col1 = groupID; + + __half h0 = smB[curr][kk + row0][tile_n + col0]; + __half h1 = smB[curr][kk + row1][tile_n + col1]; + + b_frag[p] = __half_as_ushort(h0) | (uint32_t(__half_as_ushort(h1)) << 16); + } + + // Execute MMA: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, " + "{%8, %9}, " + "{%0, %1, %2, %3};" + : "+f"(acc[wm][wn][0]), "+f"(acc[wm][wn][1]), + "+f"(acc[wm][wn][2]), "+f"(acc[wm][wn][3]) + : "r"(a_frag[0]), "r"(a_frag[1]), + "r"(a_frag[2]), "r"(a_frag[3]), + "r"(b_frag[0]), "r"(b_frag[1]) + ); + } + } + } + + cp_async_wait_0(); + __syncthreads(); + } + + // ====== Epilogue: Store results 
====== + // C fragment mapping (16x8): + // c[i] where i=0..3: + // row = groupID + 8 * (i / 2) + // col = tid_in_group * 2 + (i % 2) + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + int tile_m = cta_m + warp_m + wm * MMA_M; + int tile_n = cta_n + warp_n + wn * MMA_N; + + int out_row0 = tile_m + c_row_base; + int out_row1 = tile_m + c_row_base + 8; + int out_col0 = tile_n + c_col_base; + int out_col1 = tile_n + c_col_base + 1; + + if (out_row0 < M && out_col0 < N) C[out_row0 * N + out_col0] = __float2half(acc[wm][wn][0]); + if (out_row0 < M && out_col1 < N) C[out_row0 * N + out_col1] = __float2half(acc[wm][wn][1]); + if (out_row1 < M && out_col0 < N) C[out_row1 * N + out_col0] = __float2half(acc[wm][wn][2]); + if (out_row1 < M && out_col1 < N) C[out_row1 * N + out_col1] = __float2half(acc[wm][wn][3]); + } + } +} + +// ============================================================ +// BF16 TensorCore GEMM Kernel (FP32 accumulation) +// ============================================================ +__global__ void __launch_bounds__(256, 2) +sgemm_bf16_tc_kernel( + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ B, + __nv_bfloat16* __restrict__ C, + int M, int N, int K +) { + const int tid = threadIdx.x; + const int warp_id = tid >> 5; + const int lane = tid & 31; + + const int cta_m = blockIdx.y * BM; + const int cta_n = blockIdx.x * BN; + + const int warp_row = warp_id / WARPS_N; + const int warp_col = warp_id % WARPS_N; + + const int warp_m = warp_row * (WARP_TILES_M * MMA_M); + const int warp_n = warp_col * (WARP_TILES_N * MMA_N); + + __shared__ __nv_bfloat16 smA[2][BM][BK + A_PAD]; + __shared__ __nv_bfloat16 smB[2][BK][BN + B_PAD]; + + float acc[WARP_TILES_M][WARP_TILES_N][4] = {}; + + const int num_k_tiles = K / BK; + + const int groupID = lane >> 2; + const int tid_in_group = lane & 3; + const int c_row_base = groupID; + const int c_col_base = tid_in_group * 2; + + auto load_A_async = [&](int stage, int kt) { + const int elems_per_thread = (BM * BK) / 256; + const int half_per_load = 8; + + #pragma unroll + for (int i = 0; i < elems_per_thread / half_per_load; ++i) { + int elem_idx = tid * (elems_per_thread / half_per_load) + i; + int row = (elem_idx * half_per_load) / BK; + int col = (elem_idx * half_per_load) % BK; + int gm = cta_m + row; + int gk = kt * BK + col; + if (gm < M && gk + 7 < K) { + cp_async_16(&smA[stage][row][col], &A[gm * K + gk]); + } + } + }; + + auto load_B_async = [&](int stage, int kt) { + const int elems_per_thread = (BK * BN) / 256; + const int half_per_load = 8; + + #pragma unroll + for (int i = 0; i < elems_per_thread / half_per_load; ++i) { + int elem_idx = tid * (elems_per_thread / half_per_load) + i; + int row = (elem_idx * half_per_load) / BN; + int col = (elem_idx * half_per_load) % BN; + int gk = kt * BK + row; + int gn = cta_n + col; + if (gk < K && gn + 7 < N) { + cp_async_16(&smB[stage][row][col], &B[gk * N + gn]); + } + } + }; + + load_A_async(0, 0); + load_B_async(0, 0); + cp_async_commit(); + cp_async_wait_0(); + __syncthreads(); + + for (int kt = 0; kt < num_k_tiles; ++kt) { + int curr = kt & 1; + int next = curr ^ 1; + + if (kt + 1 < num_k_tiles) { + load_A_async(next, kt + 1); + load_B_async(next, kt + 1); + } + cp_async_commit(); + + #pragma unroll + for (int kk = 0; kk < BK; kk += MMA_K) { + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + int tile_m = warp_m + wm * MMA_M; + + uint32_t a_frag[4]; + + #pragma unroll + for (int p = 
0; p < 4; ++p) { + int i0 = p * 2; + int i1 = p * 2 + 1; + + int row0 = groupID + 8 * ((i0 / 2) % 2); + int col0 = tid_in_group * 2 + (i0 % 2) + 8 * (i0 / 4); + int row1 = groupID + 8 * ((i1 / 2) % 2); + int col1 = tid_in_group * 2 + (i1 % 2) + 8 * (i1 / 4); + + __nv_bfloat16 h0 = smA[curr][tile_m + row0][kk + col0]; + __nv_bfloat16 h1 = smA[curr][tile_m + row1][kk + col1]; + + a_frag[p] = __bfloat16_as_ushort(h0) | (uint32_t(__bfloat16_as_ushort(h1)) << 16); + } + + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + int tile_n = warp_n + wn * MMA_N; + + uint32_t b_frag[2]; + + #pragma unroll + for (int p = 0; p < 2; ++p) { + int i0 = p * 2; + int i1 = p * 2 + 1; + + int row0 = tid_in_group * 2 + (i0 % 2) + 8 * (i0 / 2); + int col0 = groupID; + int row1 = tid_in_group * 2 + (i1 % 2) + 8 * (i1 / 2); + int col1 = groupID; + + __nv_bfloat16 h0 = smB[curr][kk + row0][tile_n + col0]; + __nv_bfloat16 h1 = smB[curr][kk + row1][tile_n + col1]; + + b_frag[p] = __bfloat16_as_ushort(h0) | (uint32_t(__bfloat16_as_ushort(h1)) << 16); + } + + // Execute MMA: mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, " + "{%8, %9}, " + "{%0, %1, %2, %3};" + : "+f"(acc[wm][wn][0]), "+f"(acc[wm][wn][1]), + "+f"(acc[wm][wn][2]), "+f"(acc[wm][wn][3]) + : "r"(a_frag[0]), "r"(a_frag[1]), + "r"(a_frag[2]), "r"(a_frag[3]), + "r"(b_frag[0]), "r"(b_frag[1]) + ); + } + } + } + + cp_async_wait_0(); + __syncthreads(); + } + + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + int tile_m = cta_m + warp_m + wm * MMA_M; + int tile_n = cta_n + warp_n + wn * MMA_N; + + int out_row0 = tile_m + c_row_base; + int out_row1 = tile_m + c_row_base + 8; + int out_col0 = tile_n + c_col_base; + int out_col1 = tile_n + c_col_base + 1; + + if (out_row0 < M && out_col0 < N) C[out_row0 * N + out_col0] = __float2bfloat16_rn(acc[wm][wn][0]); + if (out_row0 < M && out_col1 < N) C[out_row0 * N + out_col1] = __float2bfloat16_rn(acc[wm][wn][1]); + if (out_row1 < M && out_col0 < N) C[out_row1 * N + out_col0] = __float2bfloat16_rn(acc[wm][wn][2]); + if (out_row1 < M && out_col1 < N) C[out_row1 * N + out_col1] = __float2bfloat16_rn(acc[wm][wn][3]); + } + } +} + +// ============================================================ +// Launch functions +// ============================================================ +inline cudaError_t launch_sgemm_f16_tc( + const __half* A, const __half* B, __half* C, + int M, int N, int K, + cudaStream_t stream = 0 +) { + dim3 block(256); + dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM); + sgemm_f16_tc_kernel<<>>(A, B, C, M, N, K); + return cudaGetLastError(); +} + +inline cudaError_t launch_sgemm_bf16_tc( + const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C, + int M, int N, int K, + cudaStream_t stream = 0 +) { + dim3 block(256); + dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM); + sgemm_bf16_tc_kernel<<>>(A, B, C, M, N, K); + return cudaGetLastError(); +} + +} // namespace fp16_bf16_tc +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/matmul_f16_bf16_tc_generic.cuh b/native/ops/matmul_f16_bf16_tc_generic.cuh new file mode 100644 index 0000000..0dd1738 --- /dev/null +++ b/native/ops/matmul_f16_bf16_tc_generic.cuh @@ -0,0 +1,424 @@ +/** + * FP16/BF16 TensorCore Generic GEMM (with boundary handling) + * + * Supports arbitrary matrix sizes with M,N >= 16 and K % 8 == 0 + * Uses 
mma.sync.aligned.m16n8k8 for flexibility + * + * Trade-off: Slightly slower than TC_FAST due to boundary checks, + * but supports many more matrix sizes. + */ + +#pragma once +#include +#include +#include + +namespace pygpukit { +namespace ops { +namespace fp16_bf16_tc_generic { + +// Smaller tile for better flexibility +constexpr int BM = 64; +constexpr int BN = 64; +constexpr int BK = 8; // Match MMA K dimension + +// MMA tile dimensions (m16n8k8 for FP16) +constexpr int MMA_M = 16; +constexpr int MMA_N = 8; +constexpr int MMA_K = 8; + +// Warp configuration (4 warps = 128 threads) +constexpr int WARPS_M = 2; // 2 warps along M +constexpr int WARPS_N = 2; // 2 warps along N +constexpr int WARP_TILES_M = 2; // 2 MMA tiles per warp along M (32 rows) +constexpr int WARP_TILES_N = 4; // 4 MMA tiles per warp along N (32 cols) + +// Padding for bank conflict avoidance +constexpr int A_PAD = 8; +constexpr int B_PAD = 8; + +// ============================================================ +// Helpers +// ============================================================ +__device__ __forceinline__ uint32_t smem_u32_generic(const void* ptr) { + uint32_t addr; + asm volatile( + "{ .reg .u64 smem64; " + " cvta.to.shared.u64 smem64, %1; " + " cvt.u32.u64 %0, smem64; }" + : "=r"(addr) : "l"(ptr) + ); + return addr; +} + +__device__ __forceinline__ void cp_async_16_generic(void* smem, const void* gmem) { + uint32_t addr = smem_u32_generic(smem); + asm volatile( + "cp.async.cg.shared.global [%0], [%1], 16;" + :: "r"(addr), "l"(gmem) + ); +} + +__device__ __forceinline__ void cp_async_commit_generic() { + asm volatile("cp.async.commit_group;"); +} + +__device__ __forceinline__ void cp_async_wait_0_generic() { + asm volatile("cp.async.wait_group 0;"); +} + +// ============================================================ +// FP16 TensorCore Generic GEMM Kernel +// ============================================================ +__global__ void __launch_bounds__(128, 4) +sgemm_f16_tc_generic_kernel( + const __half* __restrict__ A, + const __half* __restrict__ B, + __half* __restrict__ C, + int M, int N, int K +) { + const int tid = threadIdx.x; + const int warp_id = tid >> 5; + const int lane = tid & 31; + + const int cta_m = blockIdx.y * BM; + const int cta_n = blockIdx.x * BN; + + const int warp_row = warp_id / WARPS_N; + const int warp_col = warp_id % WARPS_N; + + const int warp_m = warp_row * (WARP_TILES_M * MMA_M); + const int warp_n = warp_col * (WARP_TILES_N * MMA_N); + + // Shared memory + __shared__ __half smA[BM][BK + A_PAD]; + __shared__ __half smB[BK][BN + B_PAD]; + + // Accumulators (FP32) + float acc[WARP_TILES_M][WARP_TILES_N][4] = {}; + + const int num_k_tiles = (K + BK - 1) / BK; + + // Fragment index mappings for m16n8k8 + const int groupID = lane >> 2; // 0-7 + const int tid_in_group = lane & 3; // 0-3 + + // C fragment mapping + const int c_row_base = groupID; + const int c_col_base = tid_in_group * 2; + + // ====== Main loop ====== + for (int kt = 0; kt < num_k_tiles; ++kt) { + // Load A tile with boundary check + // 128 threads, BM*BK = 64*8 = 512 halves = 4 per thread + { + const int elems_per_thread = 4; + #pragma unroll + for (int i = 0; i < elems_per_thread; ++i) { + int idx = tid * elems_per_thread + i; + int row = idx / BK; + int col = idx % BK; + int gm = cta_m + row; + int gk = kt * BK + col; + + __half val = __float2half(0.0f); + if (gm < M && gk < K) { + val = A[gm * K + gk]; + } + smA[row][col] = val; + } + } + + // Load B tile with boundary check + // BK*BN = 8*64 = 512 halves = 4 
per thread + { + const int elems_per_thread = 4; + #pragma unroll + for (int i = 0; i < elems_per_thread; ++i) { + int idx = tid * elems_per_thread + i; + int row = idx / BN; + int col = idx % BN; + int gk = kt * BK + row; + int gn = cta_n + col; + + __half val = __float2half(0.0f); + if (gk < K && gn < N) { + val = B[gk * N + gn]; + } + smB[row][col] = val; + } + } + + __syncthreads(); + + // Compute MMA for this K tile (single k iteration since BK == MMA_K) + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + int tile_m = warp_m + wm * MMA_M; + + // Load A fragment (4 halves = 2 packed uint32 for m16n8k8) + // m16n8k8 A fragment: 4 registers + // a[i] for i=0..3: + // row = groupID + 8 * (i / 2) + // col = tid_in_group * 2 + (i % 2) + uint32_t a_frag[2]; + + #pragma unroll + for (int p = 0; p < 2; ++p) { + int i0 = p * 2; + int i1 = p * 2 + 1; + + int row0 = groupID + 8 * (i0 / 2); + int col0 = tid_in_group * 2 + (i0 % 2); + int row1 = groupID + 8 * (i1 / 2); + int col1 = tid_in_group * 2 + (i1 % 2); + + __half h0 = (tile_m + row0 < BM) ? smA[tile_m + row0][col0] : __float2half(0.0f); + __half h1 = (tile_m + row1 < BM) ? smA[tile_m + row1][col1] : __float2half(0.0f); + + a_frag[p] = __half_as_ushort(h0) | (uint32_t(__half_as_ushort(h1)) << 16); + } + + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + int tile_n = warp_n + wn * MMA_N; + + // Load B fragment (2 halves = 1 packed uint32 for m16n8k8) + // m16n8k8 B fragment: 2 registers + // b[i] for i=0..1: + // row = tid_in_group * 2 + (i % 2) + // col = groupID + uint32_t b_frag; + + { + int row0 = tid_in_group * 2; + int row1 = tid_in_group * 2 + 1; + int col = groupID; + + __half h0 = (row0 < BK && tile_n + col < BN) ? smB[row0][tile_n + col] : __float2half(0.0f); + __half h1 = (row1 < BK && tile_n + col < BN) ? 
smB[row1][tile_n + col] : __float2half(0.0f); + + b_frag = __half_as_ushort(h0) | (uint32_t(__half_as_ushort(h1)) << 16); + } + + // Execute MMA: mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, " + "{%4, %5}, " + "{%6}, " + "{%0, %1, %2, %3};" + : "+f"(acc[wm][wn][0]), "+f"(acc[wm][wn][1]), + "+f"(acc[wm][wn][2]), "+f"(acc[wm][wn][3]) + : "r"(a_frag[0]), "r"(a_frag[1]), + "r"(b_frag) + ); + } + } + + __syncthreads(); + } + + // ====== Epilogue: Store results with boundary check ====== + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + int tile_m = cta_m + warp_m + wm * MMA_M; + int tile_n = cta_n + warp_n + wn * MMA_N; + + int out_row0 = tile_m + c_row_base; + int out_row1 = tile_m + c_row_base + 8; + int out_col0 = tile_n + c_col_base; + int out_col1 = tile_n + c_col_base + 1; + + if (out_row0 < M && out_col0 < N) C[out_row0 * N + out_col0] = __float2half(acc[wm][wn][0]); + if (out_row0 < M && out_col1 < N) C[out_row0 * N + out_col1] = __float2half(acc[wm][wn][1]); + if (out_row1 < M && out_col0 < N) C[out_row1 * N + out_col0] = __float2half(acc[wm][wn][2]); + if (out_row1 < M && out_col1 < N) C[out_row1 * N + out_col1] = __float2half(acc[wm][wn][3]); + } + } +} + +// ============================================================ +// BF16 TensorCore Generic GEMM Kernel +// ============================================================ +__global__ void __launch_bounds__(128, 4) +sgemm_bf16_tc_generic_kernel( + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ B, + __nv_bfloat16* __restrict__ C, + int M, int N, int K +) { + const int tid = threadIdx.x; + const int warp_id = tid >> 5; + const int lane = tid & 31; + + const int cta_m = blockIdx.y * BM; + const int cta_n = blockIdx.x * BN; + + const int warp_row = warp_id / WARPS_N; + const int warp_col = warp_id % WARPS_N; + + const int warp_m = warp_row * (WARP_TILES_M * MMA_M); + const int warp_n = warp_col * (WARP_TILES_N * MMA_N); + + __shared__ __nv_bfloat16 smA[BM][BK + A_PAD]; + __shared__ __nv_bfloat16 smB[BK][BN + B_PAD]; + + float acc[WARP_TILES_M][WARP_TILES_N][4] = {}; + + const int num_k_tiles = (K + BK - 1) / BK; + + const int groupID = lane >> 2; + const int tid_in_group = lane & 3; + const int c_row_base = groupID; + const int c_col_base = tid_in_group * 2; + + for (int kt = 0; kt < num_k_tiles; ++kt) { + // Load A tile + { + const int elems_per_thread = 4; + #pragma unroll + for (int i = 0; i < elems_per_thread; ++i) { + int idx = tid * elems_per_thread + i; + int row = idx / BK; + int col = idx % BK; + int gm = cta_m + row; + int gk = kt * BK + col; + + __nv_bfloat16 val = __float2bfloat16_rn(0.0f); + if (gm < M && gk < K) { + val = A[gm * K + gk]; + } + smA[row][col] = val; + } + } + + // Load B tile + { + const int elems_per_thread = 4; + #pragma unroll + for (int i = 0; i < elems_per_thread; ++i) { + int idx = tid * elems_per_thread + i; + int row = idx / BN; + int col = idx % BN; + int gk = kt * BK + row; + int gn = cta_n + col; + + __nv_bfloat16 val = __float2bfloat16_rn(0.0f); + if (gk < K && gn < N) { + val = B[gk * N + gn]; + } + smB[row][col] = val; + } + } + + __syncthreads(); + + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + int tile_m = warp_m + wm * MMA_M; + + uint32_t a_frag[2]; + + #pragma unroll + for (int p = 0; p < 2; ++p) { + int i0 = p * 2; + int i1 = p * 2 + 1; + + int row0 = groupID + 8 * (i0 / 2); + 
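+                // Same m16n8k8 A-fragment mapping as the FP16 kernel above:
+                //   row = groupID + 8 * (i / 2), col = tid_in_group * 2 + (i % 2)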
int col0 = tid_in_group * 2 + (i0 % 2); + int row1 = groupID + 8 * (i1 / 2); + int col1 = tid_in_group * 2 + (i1 % 2); + + __nv_bfloat16 h0 = (tile_m + row0 < BM) ? smA[tile_m + row0][col0] : __float2bfloat16_rn(0.0f); + __nv_bfloat16 h1 = (tile_m + row1 < BM) ? smA[tile_m + row1][col1] : __float2bfloat16_rn(0.0f); + + a_frag[p] = __bfloat16_as_ushort(h0) | (uint32_t(__bfloat16_as_ushort(h1)) << 16); + } + + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + int tile_n = warp_n + wn * MMA_N; + + uint32_t b_frag; + + { + int row0 = tid_in_group * 2; + int row1 = tid_in_group * 2 + 1; + int col = groupID; + + __nv_bfloat16 h0 = (row0 < BK && tile_n + col < BN) ? smB[row0][tile_n + col] : __float2bfloat16_rn(0.0f); + __nv_bfloat16 h1 = (row1 < BK && tile_n + col < BN) ? smB[row1][tile_n + col] : __float2bfloat16_rn(0.0f); + + b_frag = __bfloat16_as_ushort(h0) | (uint32_t(__bfloat16_as_ushort(h1)) << 16); + } + + // Execute MMA: mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}, " + "{%4, %5}, " + "{%6}, " + "{%0, %1, %2, %3};" + : "+f"(acc[wm][wn][0]), "+f"(acc[wm][wn][1]), + "+f"(acc[wm][wn][2]), "+f"(acc[wm][wn][3]) + : "r"(a_frag[0]), "r"(a_frag[1]), + "r"(b_frag) + ); + } + } + + __syncthreads(); + } + + // Epilogue + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + int tile_m = cta_m + warp_m + wm * MMA_M; + int tile_n = cta_n + warp_n + wn * MMA_N; + + int out_row0 = tile_m + c_row_base; + int out_row1 = tile_m + c_row_base + 8; + int out_col0 = tile_n + c_col_base; + int out_col1 = tile_n + c_col_base + 1; + + if (out_row0 < M && out_col0 < N) C[out_row0 * N + out_col0] = __float2bfloat16_rn(acc[wm][wn][0]); + if (out_row0 < M && out_col1 < N) C[out_row0 * N + out_col1] = __float2bfloat16_rn(acc[wm][wn][1]); + if (out_row1 < M && out_col0 < N) C[out_row1 * N + out_col0] = __float2bfloat16_rn(acc[wm][wn][2]); + if (out_row1 < M && out_col1 < N) C[out_row1 * N + out_col1] = __float2bfloat16_rn(acc[wm][wn][3]); + } + } +} + +// ============================================================ +// Launch functions +// ============================================================ +inline cudaError_t launch_sgemm_f16_tc_generic( + const __half* A, const __half* B, __half* C, + int M, int N, int K, + cudaStream_t stream = 0 +) { + dim3 block(128); // 4 warps + dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM); + sgemm_f16_tc_generic_kernel<<>>(A, B, C, M, N, K); + return cudaGetLastError(); +} + +inline cudaError_t launch_sgemm_bf16_tc_generic( + const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C, + int M, int N, int K, + cudaStream_t stream = 0 +) { + dim3 block(128); // 4 warps + dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM); + sgemm_bf16_tc_generic_kernel<<>>(A, B, C, M, N, K); + return cudaGetLastError(); +} + +} // namespace fp16_bf16_tc_generic +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/nn/nn.cu b/native/ops/nn/nn.cu new file mode 100644 index 0000000..d6ecebb --- /dev/null +++ b/native/ops/nn/nn.cu @@ -0,0 +1,311 @@ +/** + * Neural Network operations dispatch + */ +#include "nn_kernels.cuh" +#include "../common/error.cuh" +#include "../../core/memory.hpp" +#include + +namespace pygpukit { +namespace ops { + +using namespace nn; + +// ============================================================================ +// GELU Activation +// 
============================================================================ + +GPUArray gelu(const GPUArray& input) { + if (input.dtype() != DataType::Float32 && input.dtype() != DataType::Float64 && + input.dtype() != DataType::Float16 && input.dtype() != DataType::BFloat16) { + throw std::runtime_error("gelu only supports float types"); + } + + GPUArray result(input.shape(), input.dtype()); + size_t n = input.size(); + + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (input.dtype()) { + case DataType::Float32: + gelu_f32_kernel<<>>( + static_cast(input.data()), + static_cast(result.data()), + n); + break; + case DataType::Float64: + gelu_f64_kernel<<>>( + static_cast(input.data()), + static_cast(result.data()), + n); + break; + case DataType::Float16: + gelu_f16_kernel<<>>( + static_cast(input.data()), + static_cast<__half*>(result.data()), + n); + break; + case DataType::BFloat16: + gelu_bf16_kernel<<>>( + static_cast(input.data()), + static_cast<__nv_bfloat16*>(result.data()), + n); + break; + default: + break; + } + + sync_and_check("gelu kernel failed"); + return result; +} + +// ============================================================================ +// Bias Add +// ============================================================================ + +// In-place bias add: output[batch, features] += bias[features] +void bias_add_inplace(GPUArray& output, const GPUArray& bias) { + if (output.ndim() != 2) { + throw std::runtime_error("bias_add expects 2D output tensor [batch, features]"); + } + if (bias.ndim() != 1) { + throw std::runtime_error("bias_add expects 1D bias tensor [features]"); + } + if (output.dtype() != bias.dtype()) { + throw std::runtime_error("bias_add: dtype mismatch"); + } + + size_t batch_size = output.shape()[0]; + size_t features = output.shape()[1]; + + if (bias.shape()[0] != features) { + throw std::runtime_error("bias_add: bias size must match output features"); + } + + size_t n = batch_size * features; + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (output.dtype()) { + case DataType::Float32: + bias_add_f32_kernel<<>>( + static_cast(output.data()), + static_cast(bias.data()), + batch_size, features); + break; + case DataType::Float64: + bias_add_f64_kernel<<>>( + static_cast(output.data()), + static_cast(bias.data()), + batch_size, features); + break; + case DataType::Float16: + bias_add_f16_kernel<<>>( + static_cast<__half*>(output.data()), + static_cast(bias.data()), + batch_size, features); + break; + case DataType::BFloat16: + bias_add_bf16_kernel<<>>( + static_cast<__nv_bfloat16*>(output.data()), + static_cast(bias.data()), + batch_size, features); + break; + default: + throw std::runtime_error("bias_add only supports float types"); + } + + sync_and_check("bias_add kernel failed"); +} + +// ============================================================================ +// Linear Layer: y = xW^T + b +// ============================================================================ + +GPUArray linear(const GPUArray& input, const GPUArray& weight, const GPUArray* bias) { + // input: [batch, in_features] + // weight: [out_features, in_features] + // output: [batch, out_features] + + if (input.ndim() != 2) { + throw std::runtime_error("linear expects 2D input [batch, in_features]"); + } + if (weight.ndim() != 2) { + throw std::runtime_error("linear expects 2D weight [out_features, in_features]"); + } + if (input.dtype() != weight.dtype()) { + throw 
std::runtime_error("linear: input and weight dtype mismatch"); + } + + size_t batch = input.shape()[0]; + size_t in_features = input.shape()[1]; + size_t out_features = weight.shape()[0]; + + if (weight.shape()[1] != in_features) { + throw std::runtime_error("linear: weight in_features must match input"); + } + + // Compute y = x @ W^T using matmul with transposed weight + // For now, we'll transpose weight and use matmul + // TODO: Add transpose operation or use cuBLAS GEMM directly + + // Create transposed weight [in_features, out_features] + GPUArray weight_t({in_features, out_features}, weight.dtype()); + + // Simple transpose kernel + // For MVP, we'll just do matmul(input, weight.T) + // This requires a transpose, which we'll implement inline + + // Launch transpose kernel (simple 2D transpose) + const int block_dim = 16; + dim3 block(block_dim, block_dim); + dim3 grid((out_features + block_dim - 1) / block_dim, + (in_features + block_dim - 1) / block_dim); + + // Inline transpose kernel launch + auto transpose_f32 = [](const float* src, float* dst, int rows, int cols, dim3 grid, dim3 block) { + // Simple element-wise transpose + struct TransposeArgs { const float* src; float* dst; int rows; int cols; }; + // Use a lambda kernel via NVRTC would be ideal, but for now use a simple loop + // This is temporary - proper transpose kernel should be in a separate file + }; + + // For MVP: use row-major matmul and handle transpose in a simple way + // Actually, let's use the fact that (A @ B.T) = (B @ A.T).T for some cases + // Or better: just implement it directly with cuBLAS-style GEMM semantics + + // Simplest approach for MVP: copy weight transposed element-by-element on host + // This is slow but correct for small models like GPT-2 + + // For now, compute output = input @ weight^T directly using our matmul + // Our matmul does C = A @ B where A is MxK, B is KxN, C is MxN + // We need: output = input @ weight^T + // input: [batch, in_features] = [M, K] + // weight: [out_features, in_features] = [N, K] + // weight^T: [in_features, out_features] = [K, N] + // output: [batch, out_features] = [M, N] + + // So we need to transpose weight first + // For MVP, let's assume weight is stored as [out_features, in_features] + // and we need [in_features, out_features] + + // Actually, the simplest MVP is to use a different matmul signature + // that handles transposed B directly. For now, let's just do naive CPU transpose. 
+ + // Even simpler: for MVP, assume weight is already in the right layout + // or do the computation via multiple kernels + + // Let's do: output = matmul(input, weight_transposed) + // where we transpose weight on GPU using a simple kernel + + // For GPT-2 small: in_features = 768, out_features = 768 or 3072 + // This is manageable + + // Create result first + GPUArray result({batch, out_features}, input.dtype()); + + // For MVP: use matmul with transposed semantics + // We'll add a transposed matmul later, for now do element-wise transpose + + // Temporary: use internal matmul that can handle transpose + // Our existing matmul assumes row-major A @ B + // We need A @ B^T which is equivalent to (B @ A^T)^T + + // Simplest solution: call cuBLAS-style GEMM + // For now, let's implement a simple transpose + matmul + + // Skip bias for now in basic implementation + (void)bias; + + // For MVP, return a placeholder that works for small matrices + // Real implementation would use optimized transpose + matmul + + // Actually, let's make this work by noting: + // C[i,j] = sum_k A[i,k] * B[k,j] (normal matmul) + // We want: C[i,j] = sum_k A[i,k] * W[j,k] (matmul with transposed W) + // This is GEMM with transB = true + + // Our current matmul is C = A @ B (both row-major) + // We need C = A @ B^T + + // Let's add this capability to our matmul + + throw std::runtime_error("linear: not yet implemented - use matmul + bias_add separately for MVP"); +} + +// ============================================================================ +// LayerNorm +// ============================================================================ + +GPUArray layernorm(const GPUArray& input, const GPUArray& gamma, const GPUArray& beta, float eps) { + // input: [batch, features] + // gamma: [features] + // beta: [features] + + if (input.ndim() != 2) { + throw std::runtime_error("layernorm expects 2D input [batch, features]"); + } + if (gamma.ndim() != 1 || beta.ndim() != 1) { + throw std::runtime_error("layernorm expects 1D gamma and beta"); + } + if (input.dtype() != gamma.dtype() || input.dtype() != beta.dtype()) { + throw std::runtime_error("layernorm: dtype mismatch"); + } + + size_t batch_size = input.shape()[0]; + size_t features = input.shape()[1]; + + if (gamma.shape()[0] != features || beta.shape()[0] != features) { + throw std::runtime_error("layernorm: gamma/beta size must match features"); + } + + GPUArray result(input.shape(), input.dtype()); + + // One block per row, use enough threads to cover features + int block_size = std::min(256, (int)((features + 31) / 32 * 32)); + block_size = std::max(32, block_size); + + switch (input.dtype()) { + case DataType::Float32: + layernorm_f32_kernel<<>>( + static_cast(input.data()), + static_cast(gamma.data()), + static_cast(beta.data()), + static_cast(result.data()), + batch_size, features, eps); + break; + case DataType::Float64: + layernorm_f64_kernel<<>>( + static_cast(input.data()), + static_cast(gamma.data()), + static_cast(beta.data()), + static_cast(result.data()), + batch_size, features, (double)eps); + break; + case DataType::Float16: + layernorm_f16_kernel<<>>( + static_cast(input.data()), + static_cast(gamma.data()), + static_cast(beta.data()), + static_cast<__half*>(result.data()), + batch_size, features, eps); + break; + case DataType::BFloat16: + layernorm_bf16_kernel<<>>( + static_cast(input.data()), + static_cast(gamma.data()), + static_cast(beta.data()), + static_cast<__nv_bfloat16*>(result.data()), + batch_size, features, eps); + break; + default: + 
throw std::runtime_error("layernorm only supports float types"); + } + + sync_and_check("layernorm kernel failed"); + return result; +} + +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/nn/nn_kernels.cuh b/native/ops/nn/nn_kernels.cuh new file mode 100644 index 0000000..9c6b0d1 --- /dev/null +++ b/native/ops/nn/nn_kernels.cuh @@ -0,0 +1,482 @@ +/** + * Neural Network operation kernels + * + * Provides: Linear (matmul + bias), LayerNorm, GELU + */ +#pragma once + +#include +#include +#include +#include + +namespace pygpukit { +namespace ops { +namespace nn { + +// ============================================================================ +// GELU Activation +// ============================================================================ + +// GELU approximation: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) +// tanh-based approximation (faster, close to exact) +__device__ __forceinline__ float gelu_f32(float x) { + const float c1 = 0.7978845608f; // sqrt(2/pi) + const float c2 = 0.044715f; + float x3 = x * x * x; + return x * 0.5f * (1.0f + tanhf(c1 * (x + c2 * x3))); +} + +__device__ __forceinline__ double gelu_f64(double x) { + const double c1 = 0.7978845608028654; // sqrt(2/pi) + const double c2 = 0.044715; + double x3 = x * x * x; + return x * 0.5 * (1.0 + tanh(c1 * (x + c2 * x3))); +} + +__global__ void gelu_f32_kernel(const float* __restrict__ input, + float* __restrict__ output, + size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = gelu_f32(input[idx]); + } +} + +__global__ void gelu_f64_kernel(const double* __restrict__ input, + double* __restrict__ output, + size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = gelu_f64(input[idx]); + } +} + +__global__ void gelu_f16_kernel(const __half* __restrict__ input, + __half* __restrict__ output, + size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __half2float(input[idx]); + output[idx] = __float2half(gelu_f32(x)); + } +} + +__global__ void gelu_bf16_kernel(const __nv_bfloat16* __restrict__ input, + __nv_bfloat16* __restrict__ output, + size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __bfloat162float(input[idx]); + output[idx] = __float2bfloat16(gelu_f32(x)); + } +} + +// ============================================================================ +// Bias Add (for Linear layer: y = Wx + b) +// ============================================================================ + +// Add bias to each row of output [batch, features] +// output[i,j] += bias[j] +__global__ void bias_add_f32_kernel(float* __restrict__ output, + const float* __restrict__ bias, + size_t batch_size, + size_t features) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < batch_size * features) { + size_t j = idx % features; + output[idx] += bias[j]; + } +} + +__global__ void bias_add_f64_kernel(double* __restrict__ output, + const double* __restrict__ bias, + size_t batch_size, + size_t features) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < batch_size * features) { + size_t j = idx % features; + output[idx] += bias[j]; + } +} + +__global__ void bias_add_f16_kernel(__half* __restrict__ output, + const __half* __restrict__ bias, + size_t batch_size, + size_t features) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < batch_size * features) { + size_t j = idx % features; + float out_val = 
__half2float(output[idx]); + float bias_val = __half2float(bias[j]); + output[idx] = __float2half(out_val + bias_val); + } +} + +__global__ void bias_add_bf16_kernel(__nv_bfloat16* __restrict__ output, + const __nv_bfloat16* __restrict__ bias, + size_t batch_size, + size_t features) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < batch_size * features) { + size_t j = idx % features; + float out_val = __bfloat162float(output[idx]); + float bias_val = __bfloat162float(bias[j]); + output[idx] = __float2bfloat16(out_val + bias_val); + } +} + +// ============================================================================ +// LayerNorm +// ============================================================================ + +// Layer normalization: y = (x - mean) / sqrt(var + eps) * gamma + beta +// Input: [batch, features], normalize over features dimension + +// Single-pass mean and variance using Welford's algorithm +__device__ __forceinline__ void welford_update(float& mean, float& m2, float val, int count) { + float delta = val - mean; + mean += delta / count; + float delta2 = val - mean; + m2 += delta * delta2; +} + +// LayerNorm kernel - one warp per row for small feature sizes +__global__ void layernorm_f32_kernel(const float* __restrict__ input, + const float* __restrict__ gamma, + const float* __restrict__ beta, + float* __restrict__ output, + size_t batch_size, + size_t features, + float eps) { + int row = blockIdx.x; + if (row >= batch_size) return; + + const float* row_input = input + row * features; + float* row_output = output + row * features; + + // Compute mean using parallel reduction + float sum = 0.0f; + for (int i = threadIdx.x; i < features; i += blockDim.x) { + sum += row_input[i]; + } + + // Warp-level reduction + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + + // Block-level reduction using shared memory + __shared__ float shared_sum[32]; // Max 32 warps + int lane = threadIdx.x % warpSize; + int warp_id = threadIdx.x / warpSize; + + if (lane == 0) { + shared_sum[warp_id] = sum; + } + __syncthreads(); + + // First warp reduces across warps + if (warp_id == 0) { + sum = (threadIdx.x < (blockDim.x + warpSize - 1) / warpSize) ? shared_sum[threadIdx.x] : 0.0f; + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + } + + __shared__ float mean; + if (threadIdx.x == 0) { + mean = sum / features; + } + __syncthreads(); + + // Compute variance + float var_sum = 0.0f; + for (int i = threadIdx.x; i < features; i += blockDim.x) { + float diff = row_input[i] - mean; + var_sum += diff * diff; + } + + // Warp reduction for variance + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + var_sum += __shfl_down_sync(0xffffffff, var_sum, offset); + } + + if (lane == 0) { + shared_sum[warp_id] = var_sum; + } + __syncthreads(); + + if (warp_id == 0) { + var_sum = (threadIdx.x < (blockDim.x + warpSize - 1) / warpSize) ? 
shared_sum[threadIdx.x] : 0.0f; + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + var_sum += __shfl_down_sync(0xffffffff, var_sum, offset); + } + } + + __shared__ float inv_std; + if (threadIdx.x == 0) { + inv_std = rsqrtf(var_sum / features + eps); + } + __syncthreads(); + + // Normalize and apply affine transform + for (int i = threadIdx.x; i < features; i += blockDim.x) { + float x = row_input[i]; + float normalized = (x - mean) * inv_std; + row_output[i] = normalized * gamma[i] + beta[i]; + } +} + +// Double precision LayerNorm +__global__ void layernorm_f64_kernel(const double* __restrict__ input, + const double* __restrict__ gamma, + const double* __restrict__ beta, + double* __restrict__ output, + size_t batch_size, + size_t features, + double eps) { + int row = blockIdx.x; + if (row >= batch_size) return; + + const double* row_input = input + row * features; + double* row_output = output + row * features; + + // Compute mean + double sum = 0.0; + for (int i = threadIdx.x; i < features; i += blockDim.x) { + sum += row_input[i]; + } + + // Warp-level reduction + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + + __shared__ double shared_sum[32]; + int lane = threadIdx.x % warpSize; + int warp_id = threadIdx.x / warpSize; + + if (lane == 0) { + shared_sum[warp_id] = sum; + } + __syncthreads(); + + if (warp_id == 0) { + sum = (threadIdx.x < (blockDim.x + warpSize - 1) / warpSize) ? shared_sum[threadIdx.x] : 0.0; + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + } + + __shared__ double mean; + if (threadIdx.x == 0) { + mean = sum / features; + } + __syncthreads(); + + // Compute variance + double var_sum = 0.0; + for (int i = threadIdx.x; i < features; i += blockDim.x) { + double diff = row_input[i] - mean; + var_sum += diff * diff; + } + + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + var_sum += __shfl_down_sync(0xffffffff, var_sum, offset); + } + + if (lane == 0) { + shared_sum[warp_id] = var_sum; + } + __syncthreads(); + + if (warp_id == 0) { + var_sum = (threadIdx.x < (blockDim.x + warpSize - 1) / warpSize) ? 
shared_sum[threadIdx.x] : 0.0; + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + var_sum += __shfl_down_sync(0xffffffff, var_sum, offset); + } + } + + __shared__ double inv_std; + if (threadIdx.x == 0) { + inv_std = rsqrt(var_sum / features + eps); + } + __syncthreads(); + + // Normalize and apply affine transform + for (int i = threadIdx.x; i < features; i += blockDim.x) { + double x = row_input[i]; + double normalized = (x - mean) * inv_std; + row_output[i] = normalized * gamma[i] + beta[i]; + } +} + +// FP16 LayerNorm (compute in FP32 for precision) +__global__ void layernorm_f16_kernel(const __half* __restrict__ input, + const __half* __restrict__ gamma, + const __half* __restrict__ beta, + __half* __restrict__ output, + size_t batch_size, + size_t features, + float eps) { + int row = blockIdx.x; + if (row >= batch_size) return; + + const __half* row_input = input + row * features; + __half* row_output = output + row * features; + + // Compute mean in FP32 + float sum = 0.0f; + for (int i = threadIdx.x; i < features; i += blockDim.x) { + sum += __half2float(row_input[i]); + } + + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + + __shared__ float shared_sum[32]; + int lane = threadIdx.x % warpSize; + int warp_id = threadIdx.x / warpSize; + + if (lane == 0) { + shared_sum[warp_id] = sum; + } + __syncthreads(); + + if (warp_id == 0) { + sum = (threadIdx.x < (blockDim.x + warpSize - 1) / warpSize) ? shared_sum[threadIdx.x] : 0.0f; + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + } + + __shared__ float mean; + if (threadIdx.x == 0) { + mean = sum / features; + } + __syncthreads(); + + float var_sum = 0.0f; + for (int i = threadIdx.x; i < features; i += blockDim.x) { + float diff = __half2float(row_input[i]) - mean; + var_sum += diff * diff; + } + + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + var_sum += __shfl_down_sync(0xffffffff, var_sum, offset); + } + + if (lane == 0) { + shared_sum[warp_id] = var_sum; + } + __syncthreads(); + + if (warp_id == 0) { + var_sum = (threadIdx.x < (blockDim.x + warpSize - 1) / warpSize) ? 
shared_sum[threadIdx.x] : 0.0f; + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + var_sum += __shfl_down_sync(0xffffffff, var_sum, offset); + } + } + + __shared__ float inv_std; + if (threadIdx.x == 0) { + inv_std = rsqrtf(var_sum / features + eps); + } + __syncthreads(); + + for (int i = threadIdx.x; i < features; i += blockDim.x) { + float x = __half2float(row_input[i]); + float normalized = (x - mean) * inv_std; + float g = __half2float(gamma[i]); + float b = __half2float(beta[i]); + row_output[i] = __float2half(normalized * g + b); + } +} + +// BF16 LayerNorm (compute in FP32 for precision) +__global__ void layernorm_bf16_kernel(const __nv_bfloat16* __restrict__ input, + const __nv_bfloat16* __restrict__ gamma, + const __nv_bfloat16* __restrict__ beta, + __nv_bfloat16* __restrict__ output, + size_t batch_size, + size_t features, + float eps) { + int row = blockIdx.x; + if (row >= batch_size) return; + + const __nv_bfloat16* row_input = input + row * features; + __nv_bfloat16* row_output = output + row * features; + + float sum = 0.0f; + for (int i = threadIdx.x; i < features; i += blockDim.x) { + sum += __bfloat162float(row_input[i]); + } + + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + + __shared__ float shared_sum[32]; + int lane = threadIdx.x % warpSize; + int warp_id = threadIdx.x / warpSize; + + if (lane == 0) { + shared_sum[warp_id] = sum; + } + __syncthreads(); + + if (warp_id == 0) { + sum = (threadIdx.x < (blockDim.x + warpSize - 1) / warpSize) ? shared_sum[threadIdx.x] : 0.0f; + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + } + + __shared__ float mean; + if (threadIdx.x == 0) { + mean = sum / features; + } + __syncthreads(); + + float var_sum = 0.0f; + for (int i = threadIdx.x; i < features; i += blockDim.x) { + float diff = __bfloat162float(row_input[i]) - mean; + var_sum += diff * diff; + } + + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + var_sum += __shfl_down_sync(0xffffffff, var_sum, offset); + } + + if (lane == 0) { + shared_sum[warp_id] = var_sum; + } + __syncthreads(); + + if (warp_id == 0) { + var_sum = (threadIdx.x < (blockDim.x + warpSize - 1) / warpSize) ? 
shared_sum[threadIdx.x] : 0.0f; + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + var_sum += __shfl_down_sync(0xffffffff, var_sum, offset); + } + } + + __shared__ float inv_std; + if (threadIdx.x == 0) { + inv_std = rsqrtf(var_sum / features + eps); + } + __syncthreads(); + + for (int i = threadIdx.x; i < features; i += blockDim.x) { + float x = __bfloat162float(row_input[i]); + float normalized = (x - mean) * inv_std; + float g = __bfloat162float(gamma[i]); + float b = __bfloat162float(beta[i]); + row_output[i] = __float2bfloat16(normalized * g + b); + } +} + +} // namespace nn +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/ops.cuh b/native/ops/ops.cuh new file mode 100644 index 0000000..734a38d --- /dev/null +++ b/native/ops/ops.cuh @@ -0,0 +1,98 @@ +/** + * PyGPUkit Operations - Public API + * + * This header provides access to all GPU array operations: + * - Elementwise: add, mul, sub, div + * - Unary: exp, log, relu + * - Reduction: sum, mean, max + * - Matmul: matrix multiplication with TensorCore support + */ +#pragma once + +#include "../core/memory.hpp" + +namespace pygpukit { +namespace ops { + +// ============================================================================ +// Elementwise Operations +// ============================================================================ + +// Add: c = a + b +void add(const GPUArray& a, const GPUArray& b, GPUArray& c); +GPUArray add(const GPUArray& a, const GPUArray& b); + +// Mul: c = a * b +void mul(const GPUArray& a, const GPUArray& b, GPUArray& c); +GPUArray mul(const GPUArray& a, const GPUArray& b); + +// Sub: c = a - b +void sub(const GPUArray& a, const GPUArray& b, GPUArray& c); +GPUArray sub(const GPUArray& a, const GPUArray& b); + +// Div: c = a / b +void div(const GPUArray& a, const GPUArray& b, GPUArray& c); +GPUArray div(const GPUArray& a, const GPUArray& b); + +// ============================================================================ +// Unary Operations +// ============================================================================ + +// Exp: c = exp(a) +void exp(const GPUArray& a, GPUArray& c); +GPUArray exp(const GPUArray& a); + +// Log: c = log(a) +void log(const GPUArray& a, GPUArray& c); +GPUArray log(const GPUArray& a); + +// ReLU: c = max(0, a) +void relu(const GPUArray& a, GPUArray& c); +GPUArray relu(const GPUArray& a); + +// ============================================================================ +// Reduction Operations +// ============================================================================ + +// Sum: scalar sum of all elements +GPUArray sum(const GPUArray& a); + +// Mean: scalar mean of all elements +GPUArray mean(const GPUArray& a); + +// Max: scalar max of all elements +GPUArray max(const GPUArray& a); + +// ============================================================================ +// Matrix Multiplication +// ============================================================================ + +// Matmul: c = a @ b +// Automatically selects optimal kernel based on dtype and size: +// - FP32: L2-optimized, tiled, or Ampere-optimized kernel +// - FP32 + PYGPUKIT_ALLOW_TF32=1: TF32 TensorCore kernel +// - FP16/BF16: Simple or TensorCore kernel (PYGPUKIT_ALLOW_FP16_TC=1) +void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c); +GPUArray matmul(const GPUArray& a, const GPUArray& b); + +// Matmul with explicit TF32 control +void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c, bool use_tf32); +GPUArray matmul(const GPUArray& a, const GPUArray& b, 
bool use_tf32); + +// ============================================================================ +// Neural Network Operations +// ============================================================================ + +// GELU: Gaussian Error Linear Unit activation +// y = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) +GPUArray gelu(const GPUArray& input); + +// Bias Add: output[batch, features] += bias[features] (in-place) +void bias_add_inplace(GPUArray& output, const GPUArray& bias); + +// LayerNorm: y = (x - mean) / sqrt(var + eps) * gamma + beta +// input: [batch, features], gamma/beta: [features] +GPUArray layernorm(const GPUArray& input, const GPUArray& gamma, const GPUArray& beta, float eps = 1e-5f); + +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/reduction/reduction.cu b/native/ops/reduction/reduction.cu new file mode 100644 index 0000000..f1eb7f7 --- /dev/null +++ b/native/ops/reduction/reduction.cu @@ -0,0 +1,197 @@ +/** + * Reduction operations dispatch (sum, mean, max) + */ +#include "reduction_kernels.cuh" +#include "../common/error.cuh" +#include "../../core/memory.hpp" +#include + +namespace pygpukit { +namespace ops { + +using namespace reduction; + +// ============================================================================ +// Sum +// ============================================================================ + +GPUArray sum(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("sum only supports float types"); + } + + GPUArray result({1}, a.dtype()); + size_t n = a.size(); + + const int block_size = 256; + const int max_blocks = 256; // Limit blocks for efficient atomic reduction + const int grid_size = std::min((int)((n + block_size - 1) / block_size), max_blocks); + + switch (a.dtype()) { + case DataType::Float32: + init_sum_f32_kernel<<<1, 1>>>(static_cast(result.data())); + reduce_sum_f32_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + break; + case DataType::Float64: + init_sum_f64_kernel<<<1, 1>>>(static_cast(result.data())); + reduce_sum_f64_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + break; + case DataType::Float16: + init_sum_f16_kernel<<<1, 1>>>(static_cast<__half*>(result.data())); + reduce_sum_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(result.data()), + n); + break; + case DataType::BFloat16: + init_sum_bf16_kernel<<<1, 1>>>(static_cast<__nv_bfloat16*>(result.data())); + reduce_sum_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(result.data()), + n); + break; + default: + break; + } + + sync_and_check("sum kernel failed"); + return result; +} + +// ============================================================================ +// Mean +// ============================================================================ + +GPUArray mean(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("mean only supports float types"); + } + + GPUArray result({1}, a.dtype()); + size_t n = a.size(); + + const int block_size = 256; + const int max_blocks = 256; + const int grid_size = std::min((int)((n + block_size - 1) / block_size), max_blocks); + + switch (a.dtype()) { + case DataType::Float32: { + init_sum_f32_kernel<<<1, 1>>>(static_cast(result.data())); + 
reduce_sum_f32_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + sync_and_check("mean sum kernel failed"); + scale_f32_kernel<<<1, 1>>>( + static_cast(result.data()), + 1.0f / static_cast(n)); + break; + } + case DataType::Float64: { + init_sum_f64_kernel<<<1, 1>>>(static_cast(result.data())); + reduce_sum_f64_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + sync_and_check("mean sum kernel failed"); + scale_f64_kernel<<<1, 1>>>( + static_cast(result.data()), + 1.0 / static_cast(n)); + break; + } + case DataType::Float16: { + init_sum_f16_kernel<<<1, 1>>>(static_cast<__half*>(result.data())); + reduce_sum_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(result.data()), + n); + sync_and_check("mean sum kernel failed"); + scale_f16_kernel<<<1, 1>>>( + static_cast<__half*>(result.data()), + 1.0f / static_cast(n)); + break; + } + case DataType::BFloat16: { + init_sum_bf16_kernel<<<1, 1>>>(static_cast<__nv_bfloat16*>(result.data())); + reduce_sum_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(result.data()), + n); + sync_and_check("mean sum kernel failed"); + scale_bf16_kernel<<<1, 1>>>( + static_cast<__nv_bfloat16*>(result.data()), + 1.0f / static_cast(n)); + break; + } + default: + break; + } + + sync_and_check("mean kernel failed"); + return result; +} + +// ============================================================================ +// Max +// ============================================================================ + +GPUArray max(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("max only supports float types"); + } + + GPUArray result({1}, a.dtype()); + size_t n = a.size(); + + const int block_size = 256; + const int max_blocks = 256; + const int grid_size = std::min((int)((n + block_size - 1) / block_size), max_blocks); + + switch (a.dtype()) { + case DataType::Float32: + init_max_f32_kernel<<<1, 1>>>(static_cast(result.data())); + reduce_max_f32_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + break; + case DataType::Float64: + init_max_f64_kernel<<<1, 1>>>(static_cast(result.data())); + reduce_max_f64_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + break; + case DataType::Float16: + init_max_f16_kernel<<<1, 1>>>(static_cast<__half*>(result.data())); + reduce_max_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(result.data()), + n); + break; + case DataType::BFloat16: + init_max_bf16_kernel<<<1, 1>>>(static_cast<__nv_bfloat16*>(result.data())); + reduce_max_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(result.data()), + n); + break; + default: + break; + } + + sync_and_check("max kernel failed"); + return result; +} + +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/reduction/reduction_kernels.cuh b/native/ops/reduction/reduction_kernels.cuh new file mode 100644 index 0000000..7fa5099 --- /dev/null +++ b/native/ops/reduction/reduction_kernels.cuh @@ -0,0 +1,362 @@ +/** + * Reduction operation kernels (sum, mean, max) + * Uses warp-level shuffle for efficient parallel reduction + */ +#pragma once + +#include +#include +#include +#include +#include "../common/types.cuh" + +namespace pygpukit { +namespace ops { +namespace reduction { + +// ============================================================================ +// Warp-level reduction 
primitives +// ============================================================================ + +__device__ __forceinline__ float warp_reduce_sum(float val) { + for (int offset = 16; offset > 0; offset /= 2) { + val += __shfl_down_sync(0xffffffff, val, offset); + } + return val; +} + +__device__ __forceinline__ double warp_reduce_sum_f64(double val) { + for (int offset = 16; offset > 0; offset /= 2) { + val += __shfl_down_sync(0xffffffff, val, offset); + } + return val; +} + +__device__ __forceinline__ float warp_reduce_max(float val) { + for (int offset = 16; offset > 0; offset /= 2) { + val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset)); + } + return val; +} + +__device__ __forceinline__ double warp_reduce_max_f64(double val) { + for (int offset = 16; offset > 0; offset /= 2) { + val = fmax(val, __shfl_down_sync(0xffffffff, val, offset)); + } + return val; +} + +// ============================================================================ +// Sum reduction kernels +// ============================================================================ + +__global__ void reduce_sum_f32_kernel(const float* __restrict__ input, float* __restrict__ output, size_t n) { + __shared__ float shared[32]; // One value per warp + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + // Grid-stride loop to accumulate + float sum = 0.0f; + for (size_t i = idx; i < n; i += stride) { + sum += input[i]; + } + + // Warp reduction + sum = warp_reduce_sum(sum); + + // Write warp result to shared memory + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = sum; + } + __syncthreads(); + + // Final reduction by first warp + if (warp_id == 0) { + sum = (tid < (blockDim.x + 31) / 32) ? shared[lane] : 0.0f; + sum = warp_reduce_sum(sum); + if (lane == 0) { + atomicAdd(output, sum); + } + } +} + +__global__ void reduce_sum_f64_kernel(const double* __restrict__ input, double* __restrict__ output, size_t n) { + __shared__ double shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + double sum = 0.0; + for (size_t i = idx; i < n; i += stride) { + sum += input[i]; + } + + sum = warp_reduce_sum_f64(sum); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = sum; + } + __syncthreads(); + + if (warp_id == 0) { + sum = (tid < (blockDim.x + 31) / 32) ? shared[lane] : 0.0; + sum = warp_reduce_sum_f64(sum); + if (lane == 0) { + // atomicAdd for double requires sm_60+ + atomicAdd(output, sum); + } + } +} + +// FP16 reduction - accumulate in FP32 for numerical stability +__global__ void reduce_sum_f16_kernel(const __half* __restrict__ input, __half* __restrict__ output, size_t n) { + __shared__ float shared[32]; // Accumulate in FP32 + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float sum = 0.0f; + for (size_t i = idx; i < n; i += stride) { + sum += __half2float(input[i]); + } + + sum = warp_reduce_sum(sum); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = sum; + } + __syncthreads(); + + if (warp_id == 0) { + sum = (tid < (blockDim.x + 31) / 32) ? 
shared[lane] : 0.0f; + sum = warp_reduce_sum(sum); + if (lane == 0) { + // Atomic add in FP32, then convert back + float old_val = __half2float(*output); + *output = __float2half(old_val + sum); + } + } +} + +// BF16 reduction - accumulate in FP32 for numerical stability +__global__ void reduce_sum_bf16_kernel(const __nv_bfloat16* __restrict__ input, __nv_bfloat16* __restrict__ output, size_t n) { + __shared__ float shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float sum = 0.0f; + for (size_t i = idx; i < n; i += stride) { + sum += bf16_to_float(input[i]); + } + + sum = warp_reduce_sum(sum); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = sum; + } + __syncthreads(); + + if (warp_id == 0) { + sum = (tid < (blockDim.x + 31) / 32) ? shared[lane] : 0.0f; + sum = warp_reduce_sum(sum); + if (lane == 0) { + float old_val = bf16_to_float(*output); + *output = float_to_bf16(old_val + sum); + } + } +} + +// ============================================================================ +// Max reduction kernels +// ============================================================================ + +__global__ void reduce_max_f32_kernel(const float* __restrict__ input, float* __restrict__ output, size_t n) { + __shared__ float shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float max_val = -INFINITY; + for (size_t i = idx; i < n; i += stride) { + max_val = fmaxf(max_val, input[i]); + } + + max_val = warp_reduce_max(max_val); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = max_val; + } + __syncthreads(); + + if (warp_id == 0) { + max_val = (tid < (blockDim.x + 31) / 32) ? shared[lane] : -INFINITY; + max_val = warp_reduce_max(max_val); + if (lane == 0) { + // Atomic max for float - use atomicMax with int cast trick + int* addr = (int*)output; + int expected = *addr; + while (max_val > __int_as_float(expected)) { + int old = atomicCAS(addr, expected, __float_as_int(max_val)); + if (old == expected) break; + expected = old; + } + } + } +} + +__global__ void reduce_max_f64_kernel(const double* __restrict__ input, double* __restrict__ output, size_t n) { + __shared__ double shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + double max_val = -INFINITY; + for (size_t i = idx; i < n; i += stride) { + max_val = fmax(max_val, input[i]); + } + + max_val = warp_reduce_max_f64(max_val); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = max_val; + } + __syncthreads(); + + if (warp_id == 0) { + max_val = (tid < (blockDim.x + 31) / 32) ? 
shared[lane] : -INFINITY; + max_val = warp_reduce_max_f64(max_val); + if (lane == 0) { + // Atomic max for double using CAS + unsigned long long* addr = (unsigned long long*)output; + unsigned long long expected = *addr; + while (max_val > __longlong_as_double(expected)) { + unsigned long long old = atomicCAS(addr, expected, __double_as_longlong(max_val)); + if (old == expected) break; + expected = old; + } + } + } +} + +__global__ void reduce_max_f16_kernel(const __half* __restrict__ input, __half* __restrict__ output, size_t n) { + __shared__ float shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float max_val = -INFINITY; + for (size_t i = idx; i < n; i += stride) { + max_val = fmaxf(max_val, __half2float(input[i])); + } + + max_val = warp_reduce_max(max_val); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = max_val; + } + __syncthreads(); + + if (warp_id == 0) { + max_val = (tid < (blockDim.x + 31) / 32) ? shared[lane] : -INFINITY; + max_val = warp_reduce_max(max_val); + if (lane == 0) { + float old_val = __half2float(*output); + if (max_val > old_val) { + *output = __float2half(max_val); + } + } + } +} + +__global__ void reduce_max_bf16_kernel(const __nv_bfloat16* __restrict__ input, __nv_bfloat16* __restrict__ output, size_t n) { + __shared__ float shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float max_val = -INFINITY; + for (size_t i = idx; i < n; i += stride) { + max_val = fmaxf(max_val, bf16_to_float(input[i])); + } + + max_val = warp_reduce_max(max_val); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = max_val; + } + __syncthreads(); + + if (warp_id == 0) { + max_val = (tid < (blockDim.x + 31) / 32) ? 
shared[lane] : -INFINITY; + max_val = warp_reduce_max(max_val); + if (lane == 0) { + float old_val = bf16_to_float(*output); + if (max_val > old_val) { + *output = float_to_bf16(max_val); + } + } + } +} + +// ============================================================================ +// Output initialization kernels +// ============================================================================ + +__global__ void init_sum_f32_kernel(float* output) { *output = 0.0f; } +__global__ void init_sum_f64_kernel(double* output) { *output = 0.0; } +__global__ void init_sum_f16_kernel(__half* output) { *output = __float2half(0.0f); } +__global__ void init_sum_bf16_kernel(__nv_bfloat16* output) { *output = float_to_bf16(0.0f); } +__global__ void init_max_f32_kernel(float* output) { *output = -INFINITY; } +__global__ void init_max_f64_kernel(double* output) { *output = -INFINITY; } +__global__ void init_max_f16_kernel(__half* output) { *output = __float2half(-INFINITY); } +__global__ void init_max_bf16_kernel(__nv_bfloat16* output) { *output = float_to_bf16(-INFINITY); } + +// ============================================================================ +// Scale kernels (for mean calculation) +// ============================================================================ + +__global__ void scale_f32_kernel(float* data, float scale) { + *data *= scale; +} + +__global__ void scale_f64_kernel(double* data, double scale) { + *data *= scale; +} + +__global__ void scale_f16_kernel(__half* data, float scale) { + *data = __float2half(__half2float(*data) * scale); +} + +__global__ void scale_bf16_kernel(__nv_bfloat16* data, float scale) { + *data = float_to_bf16(bf16_to_float(*data) * scale); +} + +} // namespace reduction +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/unary/unary.cu b/native/ops/unary/unary.cu new file mode 100644 index 0000000..9d6e50f --- /dev/null +++ b/native/ops/unary/unary.cu @@ -0,0 +1,176 @@ +/** + * Unary operations dispatch (exp, log, relu) + */ +#include "unary_kernels.cuh" +#include "../common/error.cuh" +#include "../../core/memory.hpp" + +namespace pygpukit { +namespace ops { + +using namespace unary; + +// ============================================================================ +// Exp +// ============================================================================ + +void exp(const GPUArray& a, GPUArray& c) { + validate_same_shape(a, c, "exp"); + validate_same_dtype(a, c, "exp"); + + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("exp only supports float types"); + } + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + exp_f32_kernel<<>>( + static_cast(a.data()), + static_cast(c.data()), n); + break; + case DataType::Float64: + exp_f64_kernel<<>>( + static_cast(a.data()), + static_cast(c.data()), n); + break; + case DataType::Float16: + exp_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(c.data()), n); + break; + case DataType::BFloat16: + exp_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(c.data()), n); + break; + default: + break; + } + sync_and_check("exp kernel failed"); +} + +GPUArray exp(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw 
std::runtime_error("exp only supports float types"); + } + GPUArray c(a.shape(), a.dtype()); + exp(a, c); + return c; +} + +// ============================================================================ +// Log +// ============================================================================ + +void log(const GPUArray& a, GPUArray& c) { + validate_same_shape(a, c, "log"); + validate_same_dtype(a, c, "log"); + + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("log only supports float types"); + } + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + log_f32_kernel<<>>( + static_cast(a.data()), + static_cast(c.data()), n); + break; + case DataType::Float64: + log_f64_kernel<<>>( + static_cast(a.data()), + static_cast(c.data()), n); + break; + case DataType::Float16: + log_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(c.data()), n); + break; + case DataType::BFloat16: + log_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(c.data()), n); + break; + default: + break; + } + sync_and_check("log kernel failed"); +} + +GPUArray log(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("log only supports float types"); + } + GPUArray c(a.shape(), a.dtype()); + log(a, c); + return c; +} + +// ============================================================================ +// ReLU +// ============================================================================ + +void relu(const GPUArray& a, GPUArray& c) { + validate_same_shape(a, c, "relu"); + validate_same_dtype(a, c, "relu"); + + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("relu only supports float types"); + } + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + relu_f32_kernel<<>>( + static_cast(a.data()), + static_cast(c.data()), n); + break; + case DataType::Float64: + relu_f64_kernel<<>>( + static_cast(a.data()), + static_cast(c.data()), n); + break; + case DataType::Float16: + relu_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(c.data()), n); + break; + case DataType::BFloat16: + relu_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(c.data()), n); + break; + default: + break; + } + sync_and_check("relu kernel failed"); +} + +GPUArray relu(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("relu only supports float types"); + } + GPUArray c(a.shape(), a.dtype()); + relu(a, c); + return c; +} + +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/unary/unary_kernels.cuh b/native/ops/unary/unary_kernels.cuh new file mode 100644 index 0000000..a434e4c --- /dev/null +++ b/native/ops/unary/unary_kernels.cuh @@ -0,0 +1,116 @@ +/** + * Unary operation kernels (exp, log, relu) + */ +#pragma once + +#include +#include +#include +#include +#include "../common/types.cuh" + +namespace pygpukit { +namespace ops { 
+namespace unary { + +// ============================================================================ +// Exp kernels +// ============================================================================ + +__global__ void exp_f32_kernel(const float* a, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = expf(a[idx]); + } +} + +__global__ void exp_f64_kernel(const double* a, double* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = ::exp(a[idx]); + } +} + +__global__ void exp_f16_kernel(const __half* a, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __float2half(expf(__half2float(a[idx]))); + } +} + +__global__ void exp_bf16_kernel(const __nv_bfloat16* a, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = float_to_bf16(expf(bf16_to_float(a[idx]))); + } +} + +// ============================================================================ +// Log kernels +// ============================================================================ + +__global__ void log_f32_kernel(const float* a, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = logf(a[idx]); + } +} + +__global__ void log_f64_kernel(const double* a, double* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = ::log(a[idx]); + } +} + +__global__ void log_f16_kernel(const __half* a, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __float2half(logf(__half2float(a[idx]))); + } +} + +__global__ void log_bf16_kernel(const __nv_bfloat16* a, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = float_to_bf16(logf(bf16_to_float(a[idx]))); + } +} + +// ============================================================================ +// ReLU kernels +// ============================================================================ + +__global__ void relu_f32_kernel(const float* a, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = fmaxf(0.0f, a[idx]); + } +} + +__global__ void relu_f64_kernel(const double* a, double* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = fmax(0.0, a[idx]); + } +} + +__global__ void relu_f16_kernel(const __half* a, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float val = __half2float(a[idx]); + c[idx] = __float2half(val > 0.0f ? val : 0.0f); + } +} + +__global__ void relu_bf16_kernel(const __nv_bfloat16* a, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float val = bf16_to_float(a[idx]); + c[idx] = float_to_bf16(val > 0.0f ? 
val : 0.0f); + } +} + +} // namespace unary +} // namespace ops +} // namespace pygpukit diff --git a/pyproject.toml b/pyproject.toml index 3c8afaa..6e8b442 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build" [project] name = "PyGPUkit" -version = "0.2.5" +version = "0.2.6" description = "A lightweight GPU runtime for Python with Rust-powered scheduler, NVRTC JIT compilation, and NumPy-like API" readme = "README.md" license = "MIT" diff --git a/rust/Cargo.lock b/rust/Cargo.lock index da2c16a..16fe527 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -164,6 +164,15 @@ version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +[[package]] +name = "memmap2" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "744133e4a0e0a658e1374cf3bf8e415c4052a15a111acd372764c55b4177d490" +dependencies = [ + "libc", +] + [[package]] name = "memoffset" version = "0.9.1" @@ -295,7 +304,9 @@ version = "0.2.0" dependencies = [ "dirs", "indexmap", + "memmap2", "parking_lot", + "safetensors", "serde", "serde_json", ] @@ -434,6 +445,16 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "safetensors" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "scopeguard" version = "1.2.0" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 2802dfd..284e947 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -17,3 +17,5 @@ uuid = { version = "1.11", features = ["v4"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" dirs = "5.0" +safetensors = "0.4" +memmap2 = "0.9" diff --git a/rust/pygpukit-core/Cargo.toml b/rust/pygpukit-core/Cargo.toml index 3f154ca..f1e1013 100644 --- a/rust/pygpukit-core/Cargo.toml +++ b/rust/pygpukit-core/Cargo.toml @@ -11,3 +11,5 @@ indexmap.workspace = true serde.workspace = true serde_json.workspace = true dirs.workspace = true +safetensors.workspace = true +memmap2.workspace = true diff --git a/rust/pygpukit-core/src/lib.rs b/rust/pygpukit-core/src/lib.rs index cf47f0d..38ad6ba 100644 --- a/rust/pygpukit-core/src/lib.rs +++ b/rust/pygpukit-core/src/lib.rs @@ -13,6 +13,7 @@ pub mod scheduler; pub mod transfer; pub mod dispatch; pub mod device; +pub mod llm; pub use memory::{MemoryBlock, MemoryPool, PoolStats, MemoryError}; pub use scheduler::{ @@ -20,6 +21,7 @@ pub use scheduler::{ AdmissionController, AdmissionConfig, AdmissionDecision, AdmissionStats, RejectReason, QosClass, QosPolicy, QosTaskMeta, QosEvaluation, QosPolicyEvaluator, QosStats, ResourceRequirements, PartitionManager, PartitionConfig, Partition, PartitionLimits, PartitionUsage, PartitionStats, PartitionError, + ExecutionContext, ContextState, ContextStats, MultiLLMController, ControllerStats, }; pub use transfer::{ TransferType, TransferOp, TransferState, AsyncTransferEngine, StreamType, TransferStats, @@ -32,3 +34,8 @@ pub use dispatch::{ KernelCache, CacheConfig, CachedKernel, CompileOptions, CacheStats, }; pub use device::{KernelType, DeviceCapabilities}; +pub use llm::{ + SafeTensorsFile, TensorInfo, TensorData, SafeTensorsError, + Dtype, load_safetensors, + Tokenizer, TokenizerError, +}; 
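The `llm` re-exports above become the public loading API for model weights and tokenizers. A minimal usage sketch follows (illustrative only: the file paths are placeholders and error handling is condensed; it uses only the types and methods defined in the files below):

```rust
use pygpukit_core::llm::{load_safetensors, Tokenizer};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Memory-maps the checkpoint; tensor bytes are only touched when a view is read.
    let model = load_safetensors("gpt2/model.safetensors")?;
    println!("{} tensors, {} bytes", model.num_tensors(), model.file_size());

    for name in model.tensor_names() {
        let info = model.tensor_info(name).unwrap();
        println!("{}: {:?} {:?} ({} elements)", name, info.dtype, info.shape, info.numel());
    }

    // Matching tokenizer.json for the same checkpoint.
    let tok = Tokenizer::from_file("gpt2/tokenizer.json")?;
    println!("vocab size: {}", tok.vocab_size());
    Ok(())
}
```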
diff --git a/rust/pygpukit-core/src/llm/mod.rs b/rust/pygpukit-core/src/llm/mod.rs new file mode 100644 index 0000000..459df1e --- /dev/null +++ b/rust/pygpukit-core/src/llm/mod.rs @@ -0,0 +1,16 @@ +//! LLM support module for PyGPUkit +//! +//! Provides: +//! - safetensors file loading +//! - Tensor metadata and data access +//! - GPU tensor allocation helpers +//! - BPE tokenizer for GPT-2 style models + +pub mod tensor_loader; +pub mod tokenizer; + +pub use tensor_loader::{ + SafeTensorsFile, TensorInfo, TensorData, SafeTensorsError, + Dtype, load_safetensors, +}; +pub use tokenizer::{Tokenizer, TokenizerError}; diff --git a/rust/pygpukit-core/src/llm/tensor_loader.rs b/rust/pygpukit-core/src/llm/tensor_loader.rs new file mode 100644 index 0000000..85ed0b3 --- /dev/null +++ b/rust/pygpukit-core/src/llm/tensor_loader.rs @@ -0,0 +1,296 @@ +//! SafeTensors file loader for PyGPUkit +//! +//! Provides memory-mapped loading of safetensors files for efficient +//! GPU tensor allocation. + +use memmap2::Mmap; +use safetensors::SafeTensors; +use std::collections::HashMap; +use std::fs::File; +use std::path::Path; + +/// Error type for SafeTensors operations +#[derive(Debug)] +pub enum SafeTensorsError { + IoError(std::io::Error), + ParseError(String), + TensorNotFound(String), + UnsupportedDtype(String), +} + +impl std::fmt::Display for SafeTensorsError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SafeTensorsError::IoError(e) => write!(f, "IO error: {}", e), + SafeTensorsError::ParseError(e) => write!(f, "Parse error: {}", e), + SafeTensorsError::TensorNotFound(name) => write!(f, "Tensor not found: {}", name), + SafeTensorsError::UnsupportedDtype(dtype) => write!(f, "Unsupported dtype: {}", dtype), + } + } +} + +impl std::error::Error for SafeTensorsError {} + +impl From for SafeTensorsError { + fn from(e: std::io::Error) -> Self { + SafeTensorsError::IoError(e) + } +} + +impl From for SafeTensorsError { + fn from(e: safetensors::SafeTensorError) -> Self { + SafeTensorsError::ParseError(e.to_string()) + } +} + +/// Data type for tensor elements +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Dtype { + Float32, + Float16, + BFloat16, + Float64, + Int32, + Int64, + Int16, + Int8, + UInt8, + Bool, +} + +impl Dtype { + /// Size in bytes of a single element + pub fn element_size(&self) -> usize { + match self { + Dtype::Float64 | Dtype::Int64 => 8, + Dtype::Float32 | Dtype::Int32 => 4, + Dtype::Float16 | Dtype::BFloat16 | Dtype::Int16 => 2, + Dtype::Int8 | Dtype::UInt8 | Dtype::Bool => 1, + } + } + + /// Convert from safetensors dtype string + pub fn from_safetensors(dtype: safetensors::Dtype) -> Result { + match dtype { + safetensors::Dtype::F32 => Ok(Dtype::Float32), + safetensors::Dtype::F16 => Ok(Dtype::Float16), + safetensors::Dtype::BF16 => Ok(Dtype::BFloat16), + safetensors::Dtype::F64 => Ok(Dtype::Float64), + safetensors::Dtype::I32 => Ok(Dtype::Int32), + safetensors::Dtype::I64 => Ok(Dtype::Int64), + safetensors::Dtype::I16 => Ok(Dtype::Int16), + safetensors::Dtype::I8 => Ok(Dtype::Int8), + safetensors::Dtype::U8 => Ok(Dtype::UInt8), + safetensors::Dtype::BOOL => Ok(Dtype::Bool), + _ => Err(SafeTensorsError::UnsupportedDtype(format!("{:?}", dtype))), + } + } +} + +/// Metadata for a single tensor +#[derive(Debug, Clone)] +pub struct TensorInfo { + /// Tensor name (key in safetensors file) + pub name: String, + /// Data type + pub dtype: Dtype, + /// Shape dimensions + pub shape: Vec, + /// Byte offset within the data section + pub offset: 
usize, + /// Total size in bytes + pub size_bytes: usize, +} + +impl TensorInfo { + /// Total number of elements + pub fn numel(&self) -> usize { + self.shape.iter().product() + } +} + +/// View into tensor data (zero-copy reference to mmap) +pub struct TensorData<'a> { + /// Tensor metadata + pub info: TensorInfo, + /// Raw bytes (slice of mmap) + pub data: &'a [u8], +} + +impl<'a> TensorData<'a> { + /// Get data as f32 slice (only valid if dtype is Float32) + pub fn as_f32(&self) -> Option<&[f32]> { + if self.info.dtype != Dtype::Float32 { + return None; + } + // Safety: data is aligned and valid for f32 + let ptr = self.data.as_ptr() as *const f32; + let len = self.data.len() / 4; + Some(unsafe { std::slice::from_raw_parts(ptr, len) }) + } + + /// Get data as f16 bytes (raw bytes, 2 per element) + pub fn as_f16_bytes(&self) -> Option<&[u8]> { + if self.info.dtype != Dtype::Float16 { + return None; + } + Some(self.data) + } + + /// Get data as bf16 bytes (raw bytes, 2 per element) + pub fn as_bf16_bytes(&self) -> Option<&[u8]> { + if self.info.dtype != Dtype::BFloat16 { + return None; + } + Some(self.data) + } +} + +/// Memory-mapped SafeTensors file +pub struct SafeTensorsFile { + /// Memory-mapped file data + _mmap: Mmap, + /// Parsed header with tensor metadata + tensor_infos: HashMap, + /// Offset to data section start + data_offset: usize, + /// Raw pointer to mmap data (for creating tensor views) + data_ptr: *const u8, + /// Total file size + file_size: usize, +} + +// Safety: SafeTensorsFile is Send because the mmap is read-only +// and the data_ptr points to immutable memory +unsafe impl Send for SafeTensorsFile {} +unsafe impl Sync for SafeTensorsFile {} + +impl SafeTensorsFile { + /// Open a safetensors file with memory mapping + pub fn open>(path: P) -> Result { + let file = File::open(path.as_ref())?; + let mmap = unsafe { Mmap::map(&file)? 
}; + let file_size = mmap.len(); + + // Parse using safetensors crate + let tensors = SafeTensors::deserialize(&mmap)?; + + // Extract tensor info + let mut tensor_infos = HashMap::new(); + for (name, view) in tensors.tensors() { + let dtype = Dtype::from_safetensors(view.dtype())?; + let shape: Vec = view.shape().to_vec(); + let data = view.data(); + + // Calculate offset from mmap start + let data_ptr = data.as_ptr(); + let mmap_ptr = mmap.as_ptr(); + let offset = data_ptr as usize - mmap_ptr as usize; + + let info = TensorInfo { + name: name.to_string(), + dtype, + shape, + offset, + size_bytes: data.len(), + }; + tensor_infos.insert(name.to_string(), info); + } + + // Data offset is after the header + // Header format: 8-byte size + JSON header + data + let header_size = u64::from_le_bytes(mmap[0..8].try_into().unwrap()) as usize; + let data_offset = 8 + header_size; + + let data_ptr = mmap.as_ptr(); + + Ok(SafeTensorsFile { + _mmap: mmap, + tensor_infos, + data_offset, + data_ptr, + file_size, + }) + } + + /// Get list of all tensor names + pub fn tensor_names(&self) -> Vec<&str> { + self.tensor_infos.keys().map(|s| s.as_str()).collect() + } + + /// Get tensor info by name + pub fn tensor_info(&self, name: &str) -> Option<&TensorInfo> { + self.tensor_infos.get(name) + } + + /// Get tensor data by name (zero-copy view into mmap) + pub fn tensor(&self, name: &str) -> Result, SafeTensorsError> { + let info = self + .tensor_infos + .get(name) + .ok_or_else(|| SafeTensorsError::TensorNotFound(name.to_string()))?; + + // Safety: offset and size are validated during parsing + let data = unsafe { + std::slice::from_raw_parts(self.data_ptr.add(info.offset), info.size_bytes) + }; + + Ok(TensorData { + info: info.clone(), + data, + }) + } + + /// Get all tensors as an iterator + pub fn tensors(&self) -> impl Iterator, SafeTensorsError>> { + self.tensor_infos + .keys() + .map(|name| self.tensor(name)) + } + + /// Total file size in bytes + pub fn file_size(&self) -> usize { + self.file_size + } + + /// Number of tensors in the file + pub fn num_tensors(&self) -> usize { + self.tensor_infos.len() + } + + /// Data section offset (after header) + pub fn data_offset(&self) -> usize { + self.data_offset + } +} + +/// Convenience function to load a safetensors file +pub fn load_safetensors>(path: P) -> Result { + SafeTensorsFile::open(path) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dtype_element_size() { + assert_eq!(Dtype::Float32.element_size(), 4); + assert_eq!(Dtype::Float16.element_size(), 2); + assert_eq!(Dtype::BFloat16.element_size(), 2); + assert_eq!(Dtype::Float64.element_size(), 8); + assert_eq!(Dtype::Int8.element_size(), 1); + } + + #[test] + fn test_tensor_info_numel() { + let info = TensorInfo { + name: "test".to_string(), + dtype: Dtype::Float32, + shape: vec![2, 3, 4], + offset: 0, + size_bytes: 96, + }; + assert_eq!(info.numel(), 24); + } +} diff --git a/rust/pygpukit-core/src/llm/tokenizer.rs b/rust/pygpukit-core/src/llm/tokenizer.rs new file mode 100644 index 0000000..ad60188 --- /dev/null +++ b/rust/pygpukit-core/src/llm/tokenizer.rs @@ -0,0 +1,377 @@ +//! Simple BPE tokenizer for GPT-2 style models +//! +//! Loads tokenizer.json format and provides basic encode/decode functionality. 
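For reference, a minimal usage sketch of the safetensors loader added above (not part of the diff). It assumes the `llm` module is reachable as `pygpukit_core::llm` and that a `model.safetensors` file exists locally; only calls defined in `tensor_loader.rs` (`load_safetensors`, `tensor_names`, `tensor`, `as_f32`, `numel`) are used.

```rust
use pygpukit_core::llm::{load_safetensors, Dtype}; // assumed re-export path

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Memory-map the file; tensor bytes are only touched when a view is read.
    let file = load_safetensors("model.safetensors")?;
    println!("{} tensors, {} bytes total", file.num_tensors(), file.file_size());

    for name in file.tensor_names() {
        let t = file.tensor(name)?; // zero-copy view into the mmap
        println!("{}: {:?} {:?} ({} elements)", name, t.info.dtype, t.info.shape, t.info.numel());

        if t.info.dtype == Dtype::Float32 {
            let values: &[f32] = t.as_f32().expect("dtype checked above");
            let _first = values.first();
        }
    }
    Ok(())
}
```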
+ +use serde::Deserialize; +use std::collections::HashMap; +use std::fs::File; +use std::io::BufReader; +use std::path::Path; + +/// Error type for tokenizer operations +#[derive(Debug)] +pub enum TokenizerError { + IoError(std::io::Error), + ParseError(String), + InvalidToken(String), +} + +impl std::fmt::Display for TokenizerError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TokenizerError::IoError(e) => write!(f, "IO error: {}", e), + TokenizerError::ParseError(e) => write!(f, "Parse error: {}", e), + TokenizerError::InvalidToken(t) => write!(f, "Invalid token: {}", t), + } + } +} + +impl std::error::Error for TokenizerError {} + +impl From for TokenizerError { + fn from(e: std::io::Error) -> Self { + TokenizerError::IoError(e) + } +} + +impl From for TokenizerError { + fn from(e: serde_json::Error) -> Self { + TokenizerError::ParseError(e.to_string()) + } +} + +/// GPT-2 style tokenizer.json model section +#[derive(Debug, Deserialize)] +struct TokenizerModel { + #[serde(rename = "type")] + model_type: Option, + vocab: HashMap, + merges: Option>, +} + +/// GPT-2 style tokenizer.json added_tokens section +#[derive(Debug, Deserialize)] +struct AddedToken { + id: u32, + content: String, + #[serde(default)] + special: bool, +} + +/// GPT-2 style tokenizer.json format +#[derive(Debug, Deserialize)] +struct TokenizerJson { + model: TokenizerModel, + #[serde(default)] + added_tokens: Vec, +} + +/// BPE merge rule +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct BpePair(String, String); + +/// Simple BPE tokenizer +pub struct Tokenizer { + /// Token string to ID mapping + encoder: HashMap, + /// ID to token string mapping + decoder: HashMap, + /// BPE merge rules (pair -> rank, lower is earlier) + bpe_ranks: HashMap, + /// Special tokens + special_tokens: HashMap, + /// Vocabulary size + vocab_size: usize, +} + +impl Tokenizer { + /// Load tokenizer from tokenizer.json file + pub fn from_file>(path: P) -> Result { + let file = File::open(path)?; + let reader = BufReader::new(file); + let tokenizer_json: TokenizerJson = serde_json::from_reader(reader)?; + + Self::from_json(tokenizer_json) + } + + /// Load tokenizer from JSON string + pub fn from_json_str(json: &str) -> Result { + let tokenizer_json: TokenizerJson = serde_json::from_str(json)?; + Self::from_json(tokenizer_json) + } + + fn from_json(json: TokenizerJson) -> Result { + let encoder = json.model.vocab; + let mut decoder: HashMap = HashMap::new(); + + for (token, id) in &encoder { + decoder.insert(*id, token.clone()); + } + + // Parse BPE merges + let mut bpe_ranks = HashMap::new(); + if let Some(merges) = json.model.merges { + for (rank, merge) in merges.iter().enumerate() { + let parts: Vec<&str> = merge.split(' ').collect(); + if parts.len() == 2 { + let pair = BpePair(parts[0].to_string(), parts[1].to_string()); + bpe_ranks.insert(pair, rank); + } + } + } + + // Process added tokens (special tokens) + let mut special_tokens = HashMap::new(); + for added in json.added_tokens { + if added.special { + special_tokens.insert(added.content.clone(), added.id); + } + // Also add to encoder/decoder + if !encoder.contains_key(&added.content) { + decoder.insert(added.id, added.content); + } + } + + let vocab_size = encoder.len(); + + Ok(Tokenizer { + encoder, + decoder, + bpe_ranks, + special_tokens, + vocab_size, + }) + } + + /// Get vocabulary size + pub fn vocab_size(&self) -> usize { + self.vocab_size + } + + /// Get BOS token ID if available + pub fn bos_token_id(&self) -> Option { + 
self.special_tokens.get("<|endoftext|>").copied() + .or_else(|| self.special_tokens.get("").copied()) + } + + /// Get EOS token ID if available + pub fn eos_token_id(&self) -> Option { + self.special_tokens.get("<|endoftext|>").copied() + .or_else(|| self.special_tokens.get("").copied()) + } + + /// Get PAD token ID if available + pub fn pad_token_id(&self) -> Option { + self.special_tokens.get("<|padding|>").copied() + .or_else(|| self.special_tokens.get("").copied()) + } + + /// Convert bytes to unicode representation (GPT-2 style) + fn byte_to_unicode() -> HashMap { + let mut byte_encoder: HashMap = HashMap::new(); + let mut n = 0u32; + + // Printable ASCII range and some extended chars + for b in 33u8..=126 { + byte_encoder.insert(b, char::from_u32(b as u32).unwrap()); + } + for b in 161u8..=172 { + byte_encoder.insert(b, char::from_u32(b as u32).unwrap()); + } + for b in 174u8..=255 { + byte_encoder.insert(b, char::from_u32(b as u32).unwrap()); + } + + // Map remaining bytes to unicode codepoints starting at 256 + for b in 0u8..=255 { + if !byte_encoder.contains_key(&b) { + byte_encoder.insert(b, char::from_u32(256 + n).unwrap()); + n += 1; + } + } + + byte_encoder + } + + /// Convert unicode back to bytes + fn unicode_to_byte() -> HashMap { + let byte_encoder = Self::byte_to_unicode(); + byte_encoder.into_iter().map(|(k, v)| (v, k)).collect() + } + + /// Get consecutive pairs from a list of symbols + fn get_pairs(word: &[String]) -> Vec { + let mut pairs = Vec::new(); + for i in 0..word.len().saturating_sub(1) { + pairs.push(BpePair(word[i].clone(), word[i + 1].clone())); + } + pairs + } + + /// Apply BPE to a word + fn bpe(&self, token: &str) -> Vec { + if token.is_empty() { + return vec![]; + } + + // Convert token to unicode chars (GPT-2 byte encoding) + let byte_encoder = Self::byte_to_unicode(); + let word: Vec = token + .bytes() + .map(|b| byte_encoder.get(&b).unwrap_or(&'?').to_string()) + .collect(); + + if word.len() == 1 { + return word; + } + + let mut word = word; + + loop { + let pairs = Self::get_pairs(&word); + if pairs.is_empty() { + break; + } + + // Find the pair with lowest rank (highest priority) + let best_pair = pairs + .iter() + .filter_map(|p| self.bpe_ranks.get(p).map(|r| (p, r))) + .min_by_key(|(_, r)| *r); + + let Some((bigram, _)) = best_pair else { + break; + }; + + // Merge the best pair + let mut new_word = Vec::new(); + let mut i = 0; + + while i < word.len() { + // Find next occurrence of first element of bigram + let j = word[i..].iter().position(|s| *s == bigram.0); + + if let Some(j) = j { + new_word.extend(word[i..i + j].iter().cloned()); + i += j; + + if i < word.len() - 1 && word[i] == bigram.0 && word[i + 1] == bigram.1 { + // Merge + new_word.push(format!("{}{}", bigram.0, bigram.1)); + i += 2; + } else { + new_word.push(word[i].clone()); + i += 1; + } + } else { + new_word.extend(word[i..].iter().cloned()); + break; + } + } + + word = new_word; + + if word.len() == 1 { + break; + } + } + + word + } + + /// Encode text to token IDs + pub fn encode(&self, text: &str) -> Vec { + let mut tokens = Vec::new(); + + // Simple word-level tokenization (split on whitespace and punctuation) + // GPT-2 uses a more complex regex, but this is MVP + let words = Self::simple_tokenize(text); + + for word in words { + let bpe_tokens = self.bpe(&word); + for bpe_token in bpe_tokens { + if let Some(&id) = self.encoder.get(&bpe_token) { + tokens.push(id); + } + } + } + + tokens + } + + /// Simple tokenization (split on whitespace, keep leading space) + fn 
simple_tokenize(text: &str) -> Vec { + let mut tokens = Vec::new(); + let mut current = String::new(); + + for ch in text.chars() { + if ch.is_whitespace() { + if !current.is_empty() { + tokens.push(current); + current = String::new(); + } + // GPT-2 encodes space as part of next token + current.push(ch); + } else { + current.push(ch); + } + } + + if !current.is_empty() { + tokens.push(current); + } + + tokens + } + + /// Decode token IDs to text + pub fn decode(&self, token_ids: &[u32]) -> String { + let unicode_to_byte = Self::unicode_to_byte(); + + let mut bytes = Vec::new(); + + for &id in token_ids { + if let Some(token) = self.decoder.get(&id) { + for ch in token.chars() { + if let Some(&b) = unicode_to_byte.get(&ch) { + bytes.push(b); + } + } + } + } + + String::from_utf8_lossy(&bytes).to_string() + } + + /// Get token string for an ID + pub fn id_to_token(&self, id: u32) -> Option<&str> { + self.decoder.get(&id).map(|s| s.as_str()) + } + + /// Get ID for a token string + pub fn token_to_id(&self, token: &str) -> Option { + self.encoder.get(token).copied() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_byte_to_unicode_coverage() { + let byte_encoder = Tokenizer::byte_to_unicode(); + // Should have mapping for all 256 bytes + assert_eq!(byte_encoder.len(), 256); + } + + #[test] + fn test_unicode_to_byte_inverse() { + let byte_encoder = Tokenizer::byte_to_unicode(); + let unicode_decoder = Tokenizer::unicode_to_byte(); + + for (b, c) in &byte_encoder { + assert_eq!(unicode_decoder.get(c), Some(b)); + } + } +} diff --git a/rust/pygpukit-core/src/scheduler/async_exec.rs b/rust/pygpukit-core/src/scheduler/async_exec.rs new file mode 100644 index 0000000..fcb532b --- /dev/null +++ b/rust/pygpukit-core/src/scheduler/async_exec.rs @@ -0,0 +1,656 @@ +//! Asynchronous Kernel Execution +//! +//! Provides non-blocking kernel dispatch with Future-based result retrieval: +//! - KernelFuture: Handle for tracking async kernel execution +//! - AsyncExecutor: Manages async kernel lifecycle per stream +//! +//! Design: +//! - dispatch_async() returns immediately with a KernelFuture +//! - Kernel executes on dedicated CUDA stream +//! - wait() blocks until kernel completes (stream synchronize) +//! 
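A small encode/decode sketch for the tokenizer above, illustrative only: the `pygpukit_core::llm` path and the `tokenizer.json` filename are assumptions, and the calls shown (`from_file`, `encode`, `decode`, `vocab_size`, `eos_token_id`) are the ones defined in `tokenizer.rs`.

```rust
use pygpukit_core::llm::Tokenizer; // assumed re-export path

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load a GPT-2 style tokenizer.json (HuggingFace tokenizer format).
    let tok = Tokenizer::from_file("tokenizer.json")?;
    println!("vocab size: {}", tok.vocab_size());

    // encode() applies the simple whitespace split plus byte-level BPE.
    let ids = tok.encode("Hello world");
    println!("ids: {:?}", ids);

    // decode() maps ids back through the byte-level table to UTF-8 text.
    println!("text: {}", tok.decode(&ids));

    if let Some(eos) = tok.eos_token_id() {
        println!("eos token id: {}", eos);
    }
    Ok(())
}
```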
- is_ready() checks completion without blocking + +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; +use parking_lot::{RwLock, Mutex}; + +/// State of an async kernel execution +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FutureState { + /// Kernel is queued but not yet launched + Pending, + /// Kernel has been launched, executing on GPU + Running, + /// Kernel execution completed successfully + Completed, + /// Kernel execution failed + Failed, + /// Kernel was cancelled + Cancelled, +} + +impl FutureState { + pub fn is_terminal(&self) -> bool { + matches!(self, FutureState::Completed | FutureState::Failed | FutureState::Cancelled) + } +} + +/// Result of an async kernel execution +#[derive(Debug, Clone)] +pub struct KernelResult { + /// Whether execution succeeded + pub success: bool, + /// Error message if failed + pub error: Option, + /// Execution time in seconds + pub exec_time: f64, + /// Output data (if any) + pub output: Option>, +} + +impl KernelResult { + pub fn success(exec_time: f64) -> Self { + Self { + success: true, + error: None, + exec_time, + output: None, + } + } + + pub fn failure(error: String) -> Self { + Self { + success: false, + error: Some(error), + exec_time: 0.0, + output: None, + } + } + + pub fn with_output(mut self, output: Vec) -> Self { + self.output = Some(output); + self + } +} + +/// Internal state for a kernel future +struct FutureInner { + state: FutureState, + result: Option, + launched_at: Option, + completed_at: Option, +} + +/// Handle for tracking async kernel execution +/// +/// Created by `AsyncExecutor::dispatch()`. Use `wait()` to block until +/// completion or `is_ready()` to check without blocking. +/// +/// # Example +/// +/// ```ignore +/// let future = executor.dispatch(request); +/// +/// // Do other work while kernel executes... +/// +/// if future.is_ready() { +/// let result = future.wait(); +/// } +/// ``` +pub struct KernelFuture { + /// Unique ID for this future + id: u64, + /// Stream ID where kernel is executing + stream_id: u32, + /// Context ID (LLM ID) + context_id: String, + /// Shared state + inner: Arc>, + /// Flag for quick ready check + ready: Arc, +} + +impl KernelFuture { + /// Create a new pending future + fn new(id: u64, stream_id: u32, context_id: String) -> Self { + Self { + id, + stream_id, + context_id, + inner: Arc::new(RwLock::new(FutureInner { + state: FutureState::Pending, + result: None, + launched_at: None, + completed_at: None, + })), + ready: Arc::new(AtomicBool::new(false)), + } + } + + /// Get future ID + pub fn id(&self) -> u64 { + self.id + } + + /// Get stream ID + pub fn stream_id(&self) -> u32 { + self.stream_id + } + + /// Get context ID + pub fn context_id(&self) -> &str { + &self.context_id + } + + /// Check if kernel execution is complete (non-blocking) + pub fn is_ready(&self) -> bool { + self.ready.load(Ordering::SeqCst) + } + + /// Get current state + pub fn state(&self) -> FutureState { + self.inner.read().state + } + + /// Wait for kernel completion (blocking) + /// + /// Returns the kernel result. If already complete, returns immediately. + /// If still running, blocks until completion. + /// + /// Note: The actual blocking is done by C++ backend via stream synchronize. + /// This method just returns the cached result after sync. 
+ pub fn wait(&self) -> KernelResult { + // Spin-wait with yield (in practice, C++ backend does the real sync) + while !self.is_ready() { + std::thread::yield_now(); + } + + let inner = self.inner.read(); + inner.result.clone().unwrap_or_else(|| KernelResult::failure("No result available".into())) + } + + /// Try to get result without blocking + pub fn try_get(&self) -> Option { + if self.is_ready() { + let inner = self.inner.read(); + inner.result.clone() + } else { + None + } + } + + /// Get execution time (if completed) + pub fn exec_time(&self) -> Option { + let inner = self.inner.read(); + match (inner.launched_at, inner.completed_at) { + (Some(start), Some(end)) => Some(end - start), + _ => None, + } + } + + // --- Internal methods (called by AsyncExecutor) --- + + fn mark_launched(&self) { + let mut inner = self.inner.write(); + inner.state = FutureState::Running; + inner.launched_at = Some(Self::now()); + } + + fn mark_completed(&self, result: KernelResult) { + let mut inner = self.inner.write(); + inner.state = FutureState::Completed; + inner.completed_at = Some(Self::now()); + inner.result = Some(result); + drop(inner); + self.ready.store(true, Ordering::SeqCst); + } + + fn mark_failed(&self, error: String) { + let mut inner = self.inner.write(); + inner.state = FutureState::Failed; + inner.completed_at = Some(Self::now()); + inner.result = Some(KernelResult::failure(error)); + drop(inner); + self.ready.store(true, Ordering::SeqCst); + } + + fn mark_cancelled(&self) { + let mut inner = self.inner.write(); + inner.state = FutureState::Cancelled; + inner.completed_at = Some(Self::now()); + inner.result = Some(KernelResult::failure("Cancelled".into())); + drop(inner); + self.ready.store(true, Ordering::SeqCst); + } + + fn now() -> f64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0) + } +} + +// Clone creates a new handle to the same future +impl Clone for KernelFuture { + fn clone(&self) -> Self { + Self { + id: self.id, + stream_id: self.stream_id, + context_id: self.context_id.clone(), + inner: Arc::clone(&self.inner), + ready: Arc::clone(&self.ready), + } + } +} + +/// Async kernel request +#[derive(Debug, Clone)] +pub struct AsyncKernelRequest { + /// Kernel function handle (CUfunction as u64) + pub kernel_handle: u64, + /// Grid dimensions (x, y, z) + pub grid: (u32, u32, u32), + /// Block dimensions (x, y, z) + pub block: (u32, u32, u32), + /// Shared memory size + pub shared_mem: u32, + /// Kernel arguments as raw pointers + pub args: Vec, + /// Optional callback ID for completion notification + pub callback_id: Option, +} + +impl AsyncKernelRequest { + pub fn new(kernel_handle: u64) -> Self { + Self { + kernel_handle, + grid: (1, 1, 1), + block: (256, 1, 1), + shared_mem: 0, + args: Vec::new(), + callback_id: None, + } + } + + pub fn with_grid(mut self, x: u32, y: u32, z: u32) -> Self { + self.grid = (x, y, z); + self + } + + pub fn with_block(mut self, x: u32, y: u32, z: u32) -> Self { + self.block = (x, y, z); + self + } + + pub fn with_shared_mem(mut self, bytes: u32) -> Self { + self.shared_mem = bytes; + self + } + + pub fn with_args(mut self, args: Vec) -> Self { + self.args = args; + self + } + + pub fn linear(kernel_handle: u64, n_elements: usize, block_size: u32) -> Self { + let grid_x = ((n_elements as u32) + block_size - 1) / block_size; + Self::new(kernel_handle) + .with_grid(grid_x, 1, 1) + .with_block(block_size, 1, 1) + } +} + +/// Statistics for async executor +#[derive(Debug, Clone, Default)] +pub struct 
AsyncExecStats { + /// Total dispatches + pub total_dispatched: u64, + /// Currently pending (not yet launched) + pub pending_count: usize, + /// Currently running + pub running_count: usize, + /// Completed successfully + pub completed_count: u64, + /// Failed + pub failed_count: u64, + /// Cancelled + pub cancelled_count: u64, + /// Average execution time + pub avg_exec_time: f64, +} + +/// Internal executor state +struct ExecutorInner { + /// All futures by ID + futures: HashMap, + /// Pending queue per stream + pending: HashMap>, + /// Running per stream + running: HashMap>, + /// Stats + total_exec_time: f64, + completed_count: u64, + failed_count: u64, + cancelled_count: u64, +} + +/// Async kernel executor +/// +/// Manages async kernel dispatch and completion tracking per stream. +/// Each ExecutionContext has its own AsyncExecutor. +pub struct AsyncExecutor { + /// Context ID (LLM ID) + context_id: String, + /// Stream ID for this executor + stream_id: u32, + /// Next future ID + next_id: AtomicU64, + /// Internal state + inner: Mutex, +} + +impl AsyncExecutor { + /// Create a new executor for a context + pub fn new(context_id: String, stream_id: u32) -> Self { + Self { + context_id, + stream_id, + next_id: AtomicU64::new(1), + inner: Mutex::new(ExecutorInner { + futures: HashMap::new(), + pending: HashMap::new(), + running: HashMap::new(), + total_exec_time: 0.0, + completed_count: 0, + failed_count: 0, + cancelled_count: 0, + }), + } + } + + /// Dispatch an async kernel + /// + /// Returns a KernelFuture that can be used to wait for completion. + /// The kernel is queued for execution on this executor's stream. + pub fn dispatch(&self, _request: AsyncKernelRequest) -> KernelFuture { + let id = self.next_id.fetch_add(1, Ordering::SeqCst); + let future = KernelFuture::new(id, self.stream_id, self.context_id.clone()); + + let mut inner = self.inner.lock(); + inner.futures.insert(id, future.clone()); + inner.pending.entry(self.stream_id).or_default().push(id); + + future + } + + /// Get futures ready for launch + /// + /// Returns future IDs that should be launched via C++ backend. 
+ pub fn get_pending(&self) -> Vec { + let inner = self.inner.lock(); + inner.pending.get(&self.stream_id).cloned().unwrap_or_default() + } + + /// Mark a future as launched + pub fn mark_launched(&self, future_id: u64) { + let mut inner = self.inner.lock(); + + // Remove from pending + if let Some(pending) = inner.pending.get_mut(&self.stream_id) { + pending.retain(|&id| id != future_id); + } + + // Add to running + inner.running.entry(self.stream_id).or_default().push(future_id); + + // Update future state + if let Some(future) = inner.futures.get(&future_id) { + future.mark_launched(); + } + } + + /// Mark a future as completed + pub fn mark_completed(&self, future_id: u64, exec_time: f64) { + let mut inner = self.inner.lock(); + + // Remove from running + if let Some(running) = inner.running.get_mut(&self.stream_id) { + running.retain(|&id| id != future_id); + } + + // Update stats + inner.total_exec_time += exec_time; + inner.completed_count += 1; + + // Update future state + if let Some(future) = inner.futures.get(&future_id) { + future.mark_completed(KernelResult::success(exec_time)); + } + } + + /// Mark a future as failed + pub fn mark_failed(&self, future_id: u64, error: String) { + let mut inner = self.inner.lock(); + + // Remove from pending or running + if let Some(pending) = inner.pending.get_mut(&self.stream_id) { + pending.retain(|&id| id != future_id); + } + if let Some(running) = inner.running.get_mut(&self.stream_id) { + running.retain(|&id| id != future_id); + } + + inner.failed_count += 1; + + if let Some(future) = inner.futures.get(&future_id) { + future.mark_failed(error); + } + } + + /// Cancel a pending future + pub fn cancel(&self, future_id: u64) -> bool { + let mut inner = self.inner.lock(); + + // Can only cancel pending futures + let was_pending = if let Some(pending) = inner.pending.get_mut(&self.stream_id) { + let before = pending.len(); + pending.retain(|&id| id != future_id); + pending.len() < before + } else { + false + }; + + if was_pending { + inner.cancelled_count += 1; + if let Some(future) = inner.futures.get(&future_id) { + future.mark_cancelled(); + } + } + + was_pending + } + + /// Get a future by ID + pub fn get_future(&self, future_id: u64) -> Option { + self.inner.lock().futures.get(&future_id).cloned() + } + + /// Check if there's pending work + pub fn has_pending(&self) -> bool { + let inner = self.inner.lock(); + !inner.pending.get(&self.stream_id).map(|v| v.is_empty()).unwrap_or(true) + } + + /// Check if there's running work + pub fn has_running(&self) -> bool { + let inner = self.inner.lock(); + !inner.running.get(&self.stream_id).map(|v| v.is_empty()).unwrap_or(true) + } + + /// Get statistics + pub fn stats(&self) -> AsyncExecStats { + let inner = self.inner.lock(); + + let pending_count = inner.pending.get(&self.stream_id).map(|v| v.len()).unwrap_or(0); + let running_count = inner.running.get(&self.stream_id).map(|v| v.len()).unwrap_or(0); + + let avg_exec = if inner.completed_count > 0 { + inner.total_exec_time / inner.completed_count as f64 + } else { + 0.0 + }; + + AsyncExecStats { + total_dispatched: self.next_id.load(Ordering::SeqCst) - 1, + pending_count, + running_count, + completed_count: inner.completed_count, + failed_count: inner.failed_count, + cancelled_count: inner.cancelled_count, + avg_exec_time: avg_exec, + } + } + + /// Garbage collect completed futures + pub fn gc(&self) { + let mut inner = self.inner.lock(); + inner.futures.retain(|_, f| !f.state().is_terminal()); + } + + /// Clear all state + pub fn clear(&self) { 
+ let mut inner = self.inner.lock(); + inner.futures.clear(); + inner.pending.clear(); + inner.running.clear(); + inner.total_exec_time = 0.0; + inner.completed_count = 0; + inner.failed_count = 0; + inner.cancelled_count = 0; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_kernel_future_creation() { + let future = KernelFuture::new(1, 0, "test".into()); + assert_eq!(future.id(), 1); + assert_eq!(future.stream_id(), 0); + assert_eq!(future.state(), FutureState::Pending); + assert!(!future.is_ready()); + } + + #[test] + fn test_future_completion() { + let future = KernelFuture::new(1, 0, "test".into()); + + future.mark_launched(); + assert_eq!(future.state(), FutureState::Running); + + future.mark_completed(KernelResult::success(0.1)); + assert_eq!(future.state(), FutureState::Completed); + assert!(future.is_ready()); + + let result = future.wait(); + assert!(result.success); + } + + #[test] + fn test_future_failure() { + let future = KernelFuture::new(1, 0, "test".into()); + + future.mark_launched(); + future.mark_failed("CUDA error".into()); + + assert_eq!(future.state(), FutureState::Failed); + assert!(future.is_ready()); + + let result = future.wait(); + assert!(!result.success); + assert_eq!(result.error, Some("CUDA error".into())); + } + + #[test] + fn test_executor_dispatch() { + let executor = AsyncExecutor::new("llm".into(), 0); + + let request = AsyncKernelRequest::linear(0x1000, 1024, 256); + let future = executor.dispatch(request); + + assert_eq!(future.state(), FutureState::Pending); + assert!(executor.has_pending()); + + let pending = executor.get_pending(); + assert_eq!(pending.len(), 1); + } + + #[test] + fn test_executor_lifecycle() { + let executor = AsyncExecutor::new("llm".into(), 0); + + let request = AsyncKernelRequest::new(0x1000); + let future = executor.dispatch(request); + let id = future.id(); + + // Launch + executor.mark_launched(id); + assert!(!executor.has_pending()); + assert!(executor.has_running()); + assert_eq!(future.state(), FutureState::Running); + + // Complete + executor.mark_completed(id, 0.05); + assert!(!executor.has_running()); + assert!(future.is_ready()); + + let stats = executor.stats(); + assert_eq!(stats.completed_count, 1); + } + + #[test] + fn test_executor_cancel() { + let executor = AsyncExecutor::new("llm".into(), 0); + + let request = AsyncKernelRequest::new(0x1000); + let future = executor.dispatch(request); + let id = future.id(); + + assert!(executor.cancel(id)); + assert_eq!(future.state(), FutureState::Cancelled); + + let stats = executor.stats(); + assert_eq!(stats.cancelled_count, 1); + } + + #[test] + fn test_multiple_dispatches() { + let executor = AsyncExecutor::new("llm".into(), 0); + + let f1 = executor.dispatch(AsyncKernelRequest::new(0x1000)); + let f2 = executor.dispatch(AsyncKernelRequest::new(0x2000)); + let f3 = executor.dispatch(AsyncKernelRequest::new(0x3000)); + + assert_eq!(executor.get_pending().len(), 3); + + executor.mark_launched(f1.id()); + executor.mark_launched(f2.id()); + + assert_eq!(executor.get_pending().len(), 1); + assert!(executor.has_running()); + + executor.mark_completed(f1.id(), 0.1); + executor.mark_completed(f2.id(), 0.2); + executor.mark_launched(f3.id()); + executor.mark_completed(f3.id(), 0.3); + + let stats = executor.stats(); + assert_eq!(stats.completed_count, 3); + assert!((stats.avg_exec_time - 0.2).abs() < 0.01); + } +} diff --git a/rust/pygpukit-core/src/scheduler/dispatch_controller.rs b/rust/pygpukit-core/src/scheduler/dispatch_controller.rs new file mode 
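To show how the dispatch bookkeeping in `async_exec.rs` is meant to be driven, here is a hedged sketch. The `pygpukit_core::scheduler` re-exports and the kernel handle value are assumptions; in the real runtime the C++ backend drains the pending queue, launches on the CUDA stream, and reports completion, which is emulated by hand below.

```rust
use pygpukit_core::scheduler::{AsyncExecutor, AsyncKernelRequest}; // assumed re-exports

fn main() {
    // One executor per execution context / CUDA stream.
    let exec = AsyncExecutor::new("llm".into(), 0);

    // 1D launch over 1M elements with 256 threads per block
    // (the grid size is computed by AsyncKernelRequest::linear).
    let request = AsyncKernelRequest::linear(0x1000, 1 << 20, 256);
    let future = exec.dispatch(request);
    assert!(exec.has_pending());

    // Stand-in for the backend: launch each queued kernel and mark it done.
    for id in exec.get_pending() {
        exec.mark_launched(id);
        exec.mark_completed(id, 0.001); // placeholder for the measured time
    }

    assert!(future.is_ready());
    let result = future.wait();
    println!("success={} exec_time={}s", result.success, result.exec_time);
}
```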
100644 index 0000000..3a91bcd --- /dev/null +++ b/rust/pygpukit-core/src/scheduler/dispatch_controller.rs @@ -0,0 +1,704 @@ +//! Multi-LLM Dispatch Controller +//! +//! Manages multiple LLM execution contexts on a single GPU: +//! - Stream pool for multi-LLM execution +//! - Execution context lifecycle management +//! - Global VRAM budget tracking +//! - Session management + +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, Ordering}; +use parking_lot::RwLock; + +use super::execution_context::{ExecutionContext, ContextState, ContextStats}; +use super::async_exec::{KernelFuture, AsyncKernelRequest, AsyncExecStats}; + +/// Controller statistics +#[derive(Debug, Clone, Default)] +pub struct ControllerStats { + /// Whether controller is initialized + pub initialized: bool, + /// Device ID + pub device_id: i32, + /// Total VRAM budget + pub total_vram_budget: usize, + /// Device total memory + pub device_total_memory: usize, + /// Total VRAM used across all contexts + pub used_vram: usize, + /// Available VRAM + pub available_vram: usize, + /// Number of active contexts + pub context_count: usize, + /// Number of streams in pool + pub stream_pool_size: usize, +} + +/// Internal controller state +struct ControllerInner { + /// Device ID + device_id: i32, + /// Total VRAM budget for all contexts + total_vram_budget: usize, + /// Device total memory (from CUDA) + device_total_memory: usize, + /// Execution contexts by LLM ID + contexts: HashMap, + /// Available stream IDs (simple pool) + available_streams: Vec, + /// Next stream ID to allocate if pool empty + next_stream_id: u32, +} + +/// Multi-LLM Dispatch Controller +/// +/// Manages execution contexts for multiple LLM instances on a single GPU. +/// Uses stream-based isolation for concurrent execution. +/// +/// # Example +/// +/// ``` +/// use pygpukit_core::scheduler::{MultiLLMController, ContextState}; +/// +/// let controller = MultiLLMController::new(); +/// // Initialize with device_id=0, device_total_memory=8GB, total_vram_budget=8GB +/// controller.initialize(0, 8 * 1024 * 1024 * 1024, 8 * 1024 * 1024 * 1024); +/// +/// // Create context for first LLM with 4GB budget +/// let stream_id = controller.create_context("gpt2_a", 4 * 1024 * 1024 * 1024).unwrap(); +/// +/// // Create context for second LLM with 4GB budget +/// let stream_id2 = controller.create_context("gpt2_b", 4 * 1024 * 1024 * 1024).unwrap(); +/// +/// // Start session +/// controller.start_session(); +/// // ... execute kernels ... +/// controller.end_session(); +/// ``` +pub struct MultiLLMController { + /// Whether controller is initialized + initialized: AtomicBool, + /// Whether a session is active + session_active: AtomicBool, + /// Internal state + inner: RwLock, +} + +impl MultiLLMController { + /// Create a new controller (uninitialized) + pub fn new() -> Self { + Self { + initialized: AtomicBool::new(false), + session_active: AtomicBool::new(false), + inner: RwLock::new(ControllerInner { + device_id: 0, + total_vram_budget: 0, + device_total_memory: 0, + contexts: HashMap::new(), + available_streams: Vec::new(), + next_stream_id: 0, + }), + } + } + + /// Initialize the controller + /// + /// # Arguments + /// + /// * `device_id` - CUDA device ID + /// * `total_vram_budget` - Total VRAM budget for all contexts (0 = device total) + /// + /// Note: This does NOT call CUDA APIs directly. The caller should: + /// 1. Initialize CUDA driver context via C++ + /// 2. Get device total memory via C++ + /// 3. 
Call this with the device info + pub fn initialize(&self, device_id: i32, device_total_memory: usize, total_vram_budget: usize) { + let mut inner = self.inner.write(); + + inner.device_id = device_id; + inner.device_total_memory = device_total_memory; + inner.total_vram_budget = if total_vram_budget == 0 || total_vram_budget > device_total_memory { + device_total_memory + } else { + total_vram_budget + }; + + // Pre-allocate stream IDs (actual CUDA streams created by C++) + inner.available_streams = (0..8).collect(); + inner.next_stream_id = 8; + + self.initialized.store(true, Ordering::SeqCst); + } + + /// Check if controller is initialized + #[inline] + pub fn is_initialized(&self) -> bool { + self.initialized.load(Ordering::SeqCst) + } + + /// Create an execution context for an LLM + /// + /// # Arguments + /// + /// * `llm_id` - Unique LLM identifier + /// * `max_vram` - Maximum VRAM for this LLM (0 = share global budget) + /// + /// # Returns + /// + /// The assigned stream ID for this context + /// + /// # Panics + /// + /// Panics if controller is not initialized or llm_id already exists + pub fn create_context(&self, llm_id: &str, max_vram: usize) -> Result { + if !self.is_initialized() { + return Err("Controller not initialized".to_string()); + } + + let mut inner = self.inner.write(); + + // Check if context already exists + if inner.contexts.contains_key(llm_id) { + return Err(format!("Context already exists for LLM: {}", llm_id)); + } + + // Acquire a stream ID + let stream_id = inner.available_streams.pop().unwrap_or_else(|| { + let id = inner.next_stream_id; + inner.next_stream_id += 1; + id + }); + + // Create context + let context = ExecutionContext::new(llm_id.to_string(), stream_id, max_vram); + inner.contexts.insert(llm_id.to_string(), context); + + Ok(stream_id) + } + + /// Get an execution context by LLM ID + pub fn get_context(&self, llm_id: &str) -> Option { + let inner = self.inner.read(); + inner.contexts.get(llm_id).map(ContextStats::from) + } + + /// Get mutable access to context for state changes + pub fn with_context_mut(&self, llm_id: &str, f: F) -> Option + where + F: FnOnce(&mut ExecutionContext) -> R, + { + let mut inner = self.inner.write(); + inner.contexts.get_mut(llm_id).map(f) + } + + /// Destroy an execution context + pub fn destroy_context(&self, llm_id: &str) -> bool { + let mut inner = self.inner.write(); + + if let Some(ctx) = inner.contexts.remove(llm_id) { + // Return stream ID to pool + inner.available_streams.push(ctx.stream_id()); + true + } else { + false + } + } + + /// List all active context IDs + pub fn list_contexts(&self) -> Vec { + let inner = self.inner.read(); + inner.contexts.keys().cloned().collect() + } + + /// Get number of active contexts + pub fn context_count(&self) -> usize { + self.inner.read().contexts.len() + } + + /// Get stream ID for a context + pub fn get_stream_id(&self, llm_id: &str) -> Option { + self.inner.read().contexts.get(llm_id).map(|c| c.stream_id()) + } + + // --- Memory Tracking --- + + /// Track a memory allocation for a context + pub fn track_allocation(&self, llm_id: &str, buffer_id: u64, size: usize) -> bool { + let inner = self.inner.read(); + if let Some(ctx) = inner.contexts.get(llm_id) { + ctx.track_allocation(buffer_id, size) + } else { + false + } + } + + /// Track a memory deallocation for a context + pub fn track_deallocation(&self, llm_id: &str, buffer_id: u64) { + let inner = self.inner.read(); + if let Some(ctx) = inner.contexts.get(llm_id) { + ctx.track_deallocation(buffer_id); + } + } + + 
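The allocation-tracking helpers above, together with the VRAM accessors that follow, are what the per-model budgets described in the README build on. A minimal sketch (context names and byte sizes are illustrative; the `pygpukit_core::scheduler` path is an assumption):

```rust
use pygpukit_core::scheduler::MultiLLMController; // assumed re-export path

const GB: usize = 1024 * 1024 * 1024;

fn main() {
    let controller = MultiLLMController::new();
    // Device 0 reports 8 GB; a budget of 0 means "use the whole device".
    controller.initialize(0, 8 * GB, 0);

    controller.create_context("llm", 4 * GB).unwrap();
    controller.create_context("tts", 2 * GB).unwrap();

    // Allocations are checked against the owning context's budget.
    assert!(controller.track_allocation("llm", 1, 3 * GB));
    assert!(!controller.track_allocation("llm", 2, 2 * GB)); // would exceed 4 GB

    println!("used = {} bytes, available = {} bytes",
             controller.used_vram(), controller.available_vram());

    // Freeing buffer 1 makes room again.
    controller.track_deallocation("llm", 1);
    assert!(controller.track_allocation("llm", 2, 2 * GB));
}
```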
/// Get total VRAM used across all contexts + pub fn used_vram(&self) -> usize { + let inner = self.inner.read(); + inner.contexts.values().map(|c| c.used_vram()).sum() + } + + /// Get available VRAM (global budget - used) + pub fn available_vram(&self) -> usize { + let inner = self.inner.read(); + inner.total_vram_budget.saturating_sub( + inner.contexts.values().map(|c| c.used_vram()).sum() + ) + } + + // --- Session Management --- + + /// Start a session (mark all contexts as running) + pub fn start_session(&self) { + if self.session_active.swap(true, Ordering::SeqCst) { + return; // Already active + } + + let mut inner = self.inner.write(); + for ctx in inner.contexts.values_mut() { + if ctx.state() == ContextState::Idle { + ctx.start(); + } + } + } + + /// End a session (mark all contexts as idle) + pub fn end_session(&self) { + if !self.session_active.swap(false, Ordering::SeqCst) { + return; // Not active + } + + let mut inner = self.inner.write(); + for ctx in inner.contexts.values_mut() { + ctx.stop(); + } + } + + /// Check if a session is active + #[inline] + pub fn is_session_active(&self) -> bool { + self.session_active.load(Ordering::SeqCst) + } + + // --- Statistics --- + + /// Get controller statistics + pub fn stats(&self) -> ControllerStats { + let inner = self.inner.read(); + let used = inner.contexts.values().map(|c| c.used_vram()).sum(); + + ControllerStats { + initialized: self.is_initialized(), + device_id: inner.device_id, + total_vram_budget: inner.total_vram_budget, + device_total_memory: inner.device_total_memory, + used_vram: used, + available_vram: inner.total_vram_budget.saturating_sub(used), + context_count: inner.contexts.len(), + stream_pool_size: inner.available_streams.len() + inner.contexts.len(), + } + } + + /// Reset the controller (destroy all contexts) + pub fn reset(&self) { + self.session_active.store(false, Ordering::SeqCst); + + let mut inner = self.inner.write(); + + // Collect stream IDs first to avoid borrow conflict + let stream_ids: Vec = inner.contexts.values().map(|c| c.stream_id()).collect(); + + // Return all stream IDs to pool + for stream_id in stream_ids { + inner.available_streams.push(stream_id); + } + + inner.contexts.clear(); + } + + // --- Async Execution --- + + /// Dispatch an async kernel for a specific LLM context + /// + /// Returns a KernelFuture that can be used to wait for completion. 
+ /// + /// # Arguments + /// + /// * `llm_id` - LLM identifier + /// * `request` - Kernel dispatch request + /// + /// # Returns + /// + /// KernelFuture for tracking execution, or error if context not found + pub fn dispatch_async(&self, llm_id: &str, request: AsyncKernelRequest) -> Result { + let inner = self.inner.read(); + let ctx = inner.contexts.get(llm_id) + .ok_or_else(|| format!("Context not found: {}", llm_id))?; + + Ok(ctx.dispatch_async(request)) + } + + /// Get pending futures for a context + pub fn get_pending_futures(&self, llm_id: &str) -> Option> { + let inner = self.inner.read(); + inner.contexts.get(llm_id).map(|c| c.get_pending_futures()) + } + + /// Mark a future as launched + pub fn mark_future_launched(&self, llm_id: &str, future_id: u64) { + let inner = self.inner.read(); + if let Some(ctx) = inner.contexts.get(llm_id) { + ctx.mark_future_launched(future_id); + } + } + + /// Mark a future as completed + pub fn mark_future_completed(&self, llm_id: &str, future_id: u64, exec_time: f64) { + let inner = self.inner.read(); + if let Some(ctx) = inner.contexts.get(llm_id) { + ctx.mark_future_completed(future_id, exec_time); + } + } + + /// Mark a future as failed + pub fn mark_future_failed(&self, llm_id: &str, future_id: u64, error: String) { + let inner = self.inner.read(); + if let Some(ctx) = inner.contexts.get(llm_id) { + ctx.mark_future_failed(future_id, error); + } + } + + /// Cancel a pending future + pub fn cancel_future(&self, llm_id: &str, future_id: u64) -> bool { + let inner = self.inner.read(); + inner.contexts.get(llm_id) + .map(|c| c.cancel_future(future_id)) + .unwrap_or(false) + } + + /// Get a future by ID from a context + pub fn get_future(&self, llm_id: &str, future_id: u64) -> Option { + let inner = self.inner.read(); + inner.contexts.get(llm_id) + .and_then(|c| c.get_future(future_id)) + } + + /// Get async execution stats for a context + pub fn async_stats(&self, llm_id: &str) -> Option { + let inner = self.inner.read(); + inner.contexts.get(llm_id).map(|c| c.async_stats()) + } + + // --- Per-Context Session Management --- + + /// Start a session for a specific context + /// + /// Unlike the global session, per-context sessions allow independent + /// LLM execution. Each context can have its own session lifecycle. 
+ pub fn start_context_session(&self, llm_id: &str) -> bool { + let inner = self.inner.read(); + if let Some(ctx) = inner.contexts.get(llm_id) { + ctx.start_session(); + true + } else { + false + } + } + + /// End a session for a specific context + pub fn end_context_session(&self, llm_id: &str) -> bool { + let inner = self.inner.read(); + if let Some(ctx) = inner.contexts.get(llm_id) { + ctx.end_session(); + true + } else { + false + } + } + + /// Check if a specific context has an active session + pub fn is_context_session_active(&self, llm_id: &str) -> Option { + let inner = self.inner.read(); + inner.contexts.get(llm_id).map(|c| c.is_session_active()) + } +} + +impl Default for MultiLLMController { + fn default() -> Self { + Self::new() + } +} + +// Thread-safe +unsafe impl Send for MultiLLMController {} +unsafe impl Sync for MultiLLMController {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_controller_creation() { + let controller = MultiLLMController::new(); + assert!(!controller.is_initialized()); + } + + #[test] + fn test_initialization() { + let controller = MultiLLMController::new(); + controller.initialize(0, 8_000_000_000, 0); + + assert!(controller.is_initialized()); + let stats = controller.stats(); + assert_eq!(stats.device_id, 0); + assert_eq!(stats.total_vram_budget, 8_000_000_000); + } + + #[test] + fn test_create_context() { + let controller = MultiLLMController::new(); + controller.initialize(0, 8_000_000_000, 0); + + let stream_id = controller.create_context("gpt2_a", 4_000_000_000).unwrap(); + assert!(stream_id < 8); // From pre-allocated pool + + let ctx = controller.get_context("gpt2_a").unwrap(); + assert_eq!(ctx.llm_id, "gpt2_a"); + assert_eq!(ctx.stream_id, stream_id); + assert_eq!(ctx.max_vram, 4_000_000_000); + } + + #[test] + fn test_multiple_contexts() { + let controller = MultiLLMController::new(); + controller.initialize(0, 8_000_000_000, 0); + + let s1 = controller.create_context("gpt2_a", 2_000_000_000).unwrap(); + let s2 = controller.create_context("gpt2_b", 2_000_000_000).unwrap(); + let s3 = controller.create_context("llama", 2_000_000_000).unwrap(); + + // All should have different stream IDs + assert_ne!(s1, s2); + assert_ne!(s2, s3); + + assert_eq!(controller.context_count(), 3); + + let ids = controller.list_contexts(); + assert!(ids.contains(&"gpt2_a".to_string())); + assert!(ids.contains(&"gpt2_b".to_string())); + assert!(ids.contains(&"llama".to_string())); + } + + #[test] + fn test_duplicate_context_error() { + let controller = MultiLLMController::new(); + controller.initialize(0, 8_000_000_000, 0); + + controller.create_context("gpt2", 0).unwrap(); + let result = controller.create_context("gpt2", 0); + + assert!(result.is_err()); + assert!(result.unwrap_err().contains("already exists")); + } + + #[test] + fn test_destroy_context() { + let controller = MultiLLMController::new(); + controller.initialize(0, 8_000_000_000, 0); + + let stream_id = controller.create_context("gpt2", 0).unwrap(); + assert!(controller.destroy_context("gpt2")); + assert!(controller.get_context("gpt2").is_none()); + + // Stream ID should be reusable + let new_stream = controller.create_context("llama", 0).unwrap(); + assert_eq!(new_stream, stream_id); + } + + #[test] + fn test_memory_tracking() { + let controller = MultiLLMController::new(); + controller.initialize(0, 8_000_000_000, 0); + + controller.create_context("gpt2", 1_000_000).unwrap(); + + assert!(controller.track_allocation("gpt2", 1, 500_000)); + 
assert!(controller.track_allocation("gpt2", 2, 400_000)); + + let ctx = controller.get_context("gpt2").unwrap(); + assert_eq!(ctx.used_vram, 900_000); + + // Should fail - exceeds per-context budget + assert!(!controller.track_allocation("gpt2", 3, 200_000)); + + // Deallocate and retry + controller.track_deallocation("gpt2", 1); + assert!(controller.track_allocation("gpt2", 3, 200_000)); + } + + #[test] + fn test_session_management() { + let controller = MultiLLMController::new(); + controller.initialize(0, 8_000_000_000, 0); + + controller.create_context("gpt2", 0).unwrap(); + + assert!(!controller.is_session_active()); + + controller.start_session(); + assert!(controller.is_session_active()); + + let ctx = controller.get_context("gpt2").unwrap(); + assert_eq!(ctx.state, ContextState::Running); + + controller.end_session(); + assert!(!controller.is_session_active()); + + let ctx = controller.get_context("gpt2").unwrap(); + assert_eq!(ctx.state, ContextState::Idle); + } + + #[test] + fn test_global_vram_tracking() { + let controller = MultiLLMController::new(); + controller.initialize(0, 1_000_000, 0); + + controller.create_context("gpt2_a", 0).unwrap(); + controller.create_context("gpt2_b", 0).unwrap(); + + controller.track_allocation("gpt2_a", 1, 300_000); + controller.track_allocation("gpt2_b", 2, 200_000); + + assert_eq!(controller.used_vram(), 500_000); + assert_eq!(controller.available_vram(), 500_000); + } + + #[test] + fn test_reset() { + let controller = MultiLLMController::new(); + controller.initialize(0, 8_000_000_000, 0); + + controller.create_context("gpt2_a", 0).unwrap(); + controller.create_context("gpt2_b", 0).unwrap(); + controller.start_session(); + + controller.reset(); + + assert!(!controller.is_session_active()); + assert_eq!(controller.context_count(), 0); + } + + #[test] + fn test_async_dispatch() { + let controller = MultiLLMController::new(); + controller.initialize(0, 8_000_000_000, 0); + + controller.create_context("tts", 0).unwrap(); + + let request = AsyncKernelRequest::new(0x1000); + let future = controller.dispatch_async("tts", request).unwrap(); + + assert_eq!(future.context_id(), "tts"); + assert!(!future.is_ready()); + + let pending = controller.get_pending_futures("tts").unwrap(); + assert_eq!(pending.len(), 1); + } + + #[test] + fn test_async_lifecycle() { + let controller = MultiLLMController::new(); + controller.initialize(0, 8_000_000_000, 0); + + controller.create_context("llm", 0).unwrap(); + + let request = AsyncKernelRequest::linear(0x2000, 1024, 256); + let future = controller.dispatch_async("llm", request).unwrap(); + let id = future.id(); + + // Launch + controller.mark_future_launched("llm", id); + assert_eq!(controller.get_pending_futures("llm").unwrap().len(), 0); + + // Complete + controller.mark_future_completed("llm", id, 0.05); + assert!(future.is_ready()); + + let stats = controller.async_stats("llm").unwrap(); + assert_eq!(stats.completed_count, 1); + } + + #[test] + fn test_per_context_session() { + let controller = MultiLLMController::new(); + controller.initialize(0, 8_000_000_000, 0); + + controller.create_context("tts", 0).unwrap(); + controller.create_context("llm", 0).unwrap(); + + // Start session for TTS only + assert!(controller.start_context_session("tts")); + assert_eq!(controller.is_context_session_active("tts"), Some(true)); + assert_eq!(controller.is_context_session_active("llm"), Some(false)); + + // Start session for LLM + assert!(controller.start_context_session("llm")); + 
assert_eq!(controller.is_context_session_active("llm"), Some(true)); + + // End TTS session, LLM continues + assert!(controller.end_context_session("tts")); + assert_eq!(controller.is_context_session_active("tts"), Some(false)); + assert_eq!(controller.is_context_session_active("llm"), Some(true)); + } + + #[test] + fn test_multi_context_async_dispatch() { + let controller = MultiLLMController::new(); + controller.initialize(0, 8_000_000_000, 0); + + controller.create_context("tts", 0).unwrap(); + controller.create_context("llm", 0).unwrap(); + controller.create_context("vision", 0).unwrap(); + + // Start independent sessions + controller.start_context_session("tts"); + controller.start_context_session("llm"); + controller.start_context_session("vision"); + + // Dispatch kernels to different contexts + let tts_future = controller.dispatch_async("tts", AsyncKernelRequest::new(0x1000)).unwrap(); + let llm_future = controller.dispatch_async("llm", AsyncKernelRequest::new(0x2000)).unwrap(); + let vision_future = controller.dispatch_async("vision", AsyncKernelRequest::new(0x3000)).unwrap(); + + // Each context has exactly one pending + assert_eq!(controller.get_pending_futures("tts").unwrap().len(), 1); + assert_eq!(controller.get_pending_futures("llm").unwrap().len(), 1); + assert_eq!(controller.get_pending_futures("vision").unwrap().len(), 1); + + // Complete them in different order + controller.mark_future_launched("llm", llm_future.id()); + controller.mark_future_completed("llm", llm_future.id(), 0.1); + + controller.mark_future_launched("tts", tts_future.id()); + controller.mark_future_completed("tts", tts_future.id(), 0.05); + + controller.mark_future_launched("vision", vision_future.id()); + controller.mark_future_completed("vision", vision_future.id(), 0.2); + + // All should be ready + assert!(tts_future.is_ready()); + assert!(llm_future.is_ready()); + assert!(vision_future.is_ready()); + + // Check completion times + assert!(tts_future.wait().exec_time < llm_future.wait().exec_time); + } +} diff --git a/rust/pygpukit-core/src/scheduler/execution_context.rs b/rust/pygpukit-core/src/scheduler/execution_context.rs new file mode 100644 index 0000000..31e34bf --- /dev/null +++ b/rust/pygpukit-core/src/scheduler/execution_context.rs @@ -0,0 +1,530 @@ +//! Execution Context for Multi-LLM Scheduling +//! +//! Provides per-LLM execution context with: +//! - Dedicated stream ID for kernel isolation +//! - Memory budget tracking +//! - State management (IDLE, RUNNING, PAUSED) +//! - Async kernel execution with KernelFuture +//! - Per-context session management +//! +//! Each LLM instance is bound to exactly one ExecutionContext. + +use std::sync::atomic::{AtomicUsize, AtomicBool, Ordering}; +use std::collections::HashMap; +use parking_lot::RwLock; + +use super::async_exec::{AsyncExecutor, AsyncKernelRequest, KernelFuture, AsyncExecStats}; + +/// State of an execution context +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ContextState { + /// Context created but not running + Idle = 0, + /// Context is actively executing kernels + Running = 1, + /// Context is paused (e.g., waiting for memory) + Paused = 2, +} + +impl Default for ContextState { + fn default() -> Self { + ContextState::Idle + } +} + +/// Per-LLM Execution Context +/// +/// Each LLM instance is bound to exactly one ExecutionContext. 
+/// Provides: +/// - Dedicated stream ID for kernel isolation +/// - Memory budget tracking +/// - State management +/// - Async kernel execution +/// - Per-context session +pub struct ExecutionContext { + /// Unique identifier for the LLM instance + llm_id: String, + /// Current state + state: ContextState, + /// Assigned stream ID (managed by C++ StreamPool) + stream_id: u32, + /// Maximum VRAM budget in bytes (0 = unlimited) + max_vram: usize, + /// Currently used VRAM + used_vram: AtomicUsize, + /// Allocated buffer tracking: buffer_id -> size + allocated_buffers: RwLock>, + /// Async kernel executor + executor: AsyncExecutor, + /// Per-context session active flag + session_active: AtomicBool, +} + +impl ExecutionContext { + /// Create a new execution context + /// + /// # Arguments + /// + /// * `llm_id` - Unique identifier for the LLM instance + /// * `stream_id` - Assigned stream ID + /// * `max_vram` - Maximum VRAM budget in bytes (0 = unlimited) + pub fn new(llm_id: String, stream_id: u32, max_vram: usize) -> Self { + let executor = AsyncExecutor::new(llm_id.clone(), stream_id); + Self { + llm_id, + state: ContextState::Idle, + stream_id, + max_vram, + used_vram: AtomicUsize::new(0), + allocated_buffers: RwLock::new(HashMap::new()), + executor, + session_active: AtomicBool::new(false), + } + } + + // --- Accessors --- + + /// Get the LLM ID + #[inline] + pub fn llm_id(&self) -> &str { + &self.llm_id + } + + /// Get current state + #[inline] + pub fn state(&self) -> ContextState { + self.state + } + + /// Get assigned stream ID + #[inline] + pub fn stream_id(&self) -> u32 { + self.stream_id + } + + /// Get maximum VRAM budget + #[inline] + pub fn max_vram(&self) -> usize { + self.max_vram + } + + /// Get currently used VRAM + #[inline] + pub fn used_vram(&self) -> usize { + self.used_vram.load(Ordering::SeqCst) + } + + /// Get available VRAM + #[inline] + pub fn available_vram(&self) -> usize { + if self.max_vram == 0 { + usize::MAX + } else { + self.max_vram.saturating_sub(self.used_vram()) + } + } + + /// Get number of allocated buffers + pub fn buffer_count(&self) -> usize { + self.allocated_buffers.read().len() + } + + // --- State Management --- + + /// Set context state + pub fn set_state(&mut self, state: ContextState) { + self.state = state; + } + + /// Start the context (set to Running) + pub fn start(&mut self) { + self.state = ContextState::Running; + } + + /// Pause the context + pub fn pause(&mut self) { + self.state = ContextState::Paused; + } + + /// Stop the context (set to Idle) + pub fn stop(&mut self) { + self.state = ContextState::Idle; + } + + /// Check if context is running + #[inline] + pub fn is_running(&self) -> bool { + self.state == ContextState::Running + } + + // --- Memory Tracking --- + + /// Check if allocation fits within budget + pub fn can_allocate(&self, size: usize) -> bool { + if self.max_vram == 0 { + return true; // Unlimited + } + self.used_vram() + size <= self.max_vram + } + + /// Track a memory allocation + /// + /// # Arguments + /// + /// * `buffer_id` - Unique buffer identifier + /// * `size` - Size in bytes + /// + /// # Returns + /// + /// `true` if allocation fits within budget, `false` otherwise + pub fn track_allocation(&self, buffer_id: u64, size: usize) -> bool { + if !self.can_allocate(size) { + return false; + } + + let mut buffers = self.allocated_buffers.write(); + buffers.insert(buffer_id, size); + self.used_vram.fetch_add(size, Ordering::SeqCst); + true + } + + /// Track a memory deallocation + /// + /// # Arguments + /// 
+ /// * `buffer_id` - Buffer identifier to deallocate + pub fn track_deallocation(&self, buffer_id: u64) { + let mut buffers = self.allocated_buffers.write(); + if let Some(size) = buffers.remove(&buffer_id) { + // Saturating sub to handle potential underflow + let current = self.used_vram.load(Ordering::SeqCst); + let new_val = current.saturating_sub(size); + self.used_vram.store(new_val, Ordering::SeqCst); + } + } + + /// Get size of a specific buffer + pub fn get_buffer_size(&self, buffer_id: u64) -> Option { + self.allocated_buffers.read().get(&buffer_id).copied() + } + + /// Clear all tracked allocations + pub fn clear_allocations(&self) { + let mut buffers = self.allocated_buffers.write(); + buffers.clear(); + self.used_vram.store(0, Ordering::SeqCst); + } + + // --- Async Execution --- + + /// Dispatch an async kernel + /// + /// Returns a KernelFuture that can be used to wait for completion. + /// The kernel is queued for execution on this context's stream. + /// + /// # Example + /// + /// ```ignore + /// let request = AsyncKernelRequest::linear(kernel_handle, n_elements, 256); + /// let future = ctx.dispatch_async(request); + /// + /// // Do other work... + /// + /// let result = future.wait(); + /// ``` + pub fn dispatch_async(&self, request: AsyncKernelRequest) -> KernelFuture { + self.executor.dispatch(request) + } + + /// Get pending futures for this context + pub fn get_pending_futures(&self) -> Vec { + self.executor.get_pending() + } + + /// Mark a future as launched + pub fn mark_future_launched(&self, future_id: u64) { + self.executor.mark_launched(future_id); + } + + /// Mark a future as completed + pub fn mark_future_completed(&self, future_id: u64, exec_time: f64) { + self.executor.mark_completed(future_id, exec_time); + } + + /// Mark a future as failed + pub fn mark_future_failed(&self, future_id: u64, error: String) { + self.executor.mark_failed(future_id, error); + } + + /// Cancel a pending future + pub fn cancel_future(&self, future_id: u64) -> bool { + self.executor.cancel(future_id) + } + + /// Get a future by ID + pub fn get_future(&self, future_id: u64) -> Option { + self.executor.get_future(future_id) + } + + /// Check if there are pending kernels + pub fn has_pending_kernels(&self) -> bool { + self.executor.has_pending() + } + + /// Check if there are running kernels + pub fn has_running_kernels(&self) -> bool { + self.executor.has_running() + } + + /// Get async execution statistics + pub fn async_stats(&self) -> AsyncExecStats { + self.executor.stats() + } + + // --- Per-Context Session --- + + /// Start a session for this context + /// + /// Unlike the global session, per-context sessions allow independent + /// LLM execution. Each context can have its own session lifecycle. + pub fn start_session(&self) { + self.session_active.store(true, Ordering::SeqCst); + } + + /// End the session for this context + /// + /// This does NOT synchronize the stream - use `sync()` for that. + /// It just marks the session as inactive. 
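Pulling the pieces of `ExecutionContext` together, a usage sketch under the assumption that the type is re-exported from `pygpukit_core::scheduler`; the buffer ID, sizes, and kernel handle are illustrative, and the launch/complete calls stand in for what the C++ backend does.

```rust
use pygpukit_core::scheduler::{ExecutionContext, AsyncKernelRequest}; // assumed re-exports

fn main() {
    // Context "llm" bound to stream 0 with a 1 GiB VRAM budget.
    let mut ctx = ExecutionContext::new("llm".to_string(), 0, 1 << 30);
    ctx.start();
    ctx.start_session();

    // Budget-checked allocation tracking (buffer IDs are supplied by the caller).
    assert!(ctx.track_allocation(1, 512 << 20));
    assert_eq!(ctx.available_vram(), 512 << 20);

    // Queue an async kernel on this context's stream, then drive the
    // bookkeeping by hand in place of the real backend.
    let future = ctx.dispatch_async(AsyncKernelRequest::new(0x2000));
    ctx.mark_future_launched(future.id());
    ctx.mark_future_completed(future.id(), 0.002);
    assert!(future.is_ready());

    ctx.track_deallocation(1);
    ctx.end_session();
    ctx.stop();
}
```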
+ pub fn end_session(&self) { + self.session_active.store(false, Ordering::SeqCst); + } + + /// Check if a session is active for this context + pub fn is_session_active(&self) -> bool { + self.session_active.load(Ordering::SeqCst) + } + + /// Garbage collect completed futures + pub fn gc_futures(&self) { + self.executor.gc(); + } + + /// Clear all async state + pub fn clear_async_state(&self) { + self.executor.clear(); + } +} + +/// Execution context statistics +#[derive(Debug, Clone, Default)] +pub struct ContextStats { + /// LLM ID + pub llm_id: String, + /// Current state + pub state: ContextState, + /// Assigned stream ID + pub stream_id: u32, + /// Maximum VRAM budget + pub max_vram: usize, + /// Currently used VRAM + pub used_vram: usize, + /// Available VRAM + pub available_vram: usize, + /// Number of allocated buffers + pub buffer_count: usize, +} + +impl From<&ExecutionContext> for ContextStats { + fn from(ctx: &ExecutionContext) -> Self { + Self { + llm_id: ctx.llm_id.clone(), + state: ctx.state, + stream_id: ctx.stream_id, + max_vram: ctx.max_vram, + used_vram: ctx.used_vram(), + available_vram: ctx.available_vram(), + buffer_count: ctx.buffer_count(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_context_creation() { + let ctx = ExecutionContext::new("gpt2".to_string(), 0, 1024 * 1024); + assert_eq!(ctx.llm_id(), "gpt2"); + assert_eq!(ctx.stream_id(), 0); + assert_eq!(ctx.max_vram(), 1024 * 1024); + assert_eq!(ctx.used_vram(), 0); + assert_eq!(ctx.state(), ContextState::Idle); + } + + #[test] + fn test_state_transitions() { + let mut ctx = ExecutionContext::new("gpt2".to_string(), 0, 0); + + assert_eq!(ctx.state(), ContextState::Idle); + + ctx.start(); + assert_eq!(ctx.state(), ContextState::Running); + assert!(ctx.is_running()); + + ctx.pause(); + assert_eq!(ctx.state(), ContextState::Paused); + + ctx.stop(); + assert_eq!(ctx.state(), ContextState::Idle); + } + + #[test] + fn test_memory_tracking() { + let ctx = ExecutionContext::new("gpt2".to_string(), 0, 1000); + + // Track allocations + assert!(ctx.track_allocation(1, 400)); + assert_eq!(ctx.used_vram(), 400); + + assert!(ctx.track_allocation(2, 400)); + assert_eq!(ctx.used_vram(), 800); + + // Should fail - exceeds budget + assert!(!ctx.track_allocation(3, 300)); + assert_eq!(ctx.used_vram(), 800); + + // Deallocate + ctx.track_deallocation(1); + assert_eq!(ctx.used_vram(), 400); + + // Now should succeed + assert!(ctx.track_allocation(3, 300)); + assert_eq!(ctx.used_vram(), 700); + } + + #[test] + fn test_unlimited_budget() { + let ctx = ExecutionContext::new("gpt2".to_string(), 0, 0); + + // Should always succeed with unlimited budget + assert!(ctx.can_allocate(usize::MAX / 2)); + assert!(ctx.track_allocation(1, 1_000_000_000)); + assert_eq!(ctx.available_vram(), usize::MAX); + } + + #[test] + fn test_context_stats() { + let ctx = ExecutionContext::new("gpt2".to_string(), 5, 2000); + ctx.track_allocation(1, 500); + + let stats = ContextStats::from(&ctx); + assert_eq!(stats.llm_id, "gpt2"); + assert_eq!(stats.stream_id, 5); + assert_eq!(stats.max_vram, 2000); + assert_eq!(stats.used_vram, 500); + assert_eq!(stats.available_vram, 1500); + assert_eq!(stats.buffer_count, 1); + } + + #[test] + fn test_async_dispatch() { + let ctx = ExecutionContext::new("tts".to_string(), 0, 0); + + let request = AsyncKernelRequest::new(0x1000); + let future = ctx.dispatch_async(request); + + assert!(ctx.has_pending_kernels()); + assert!(!ctx.has_running_kernels()); + + let pending = ctx.get_pending_futures(); 
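+        // A kernel that has been dispatched but not yet launched should appear
+        // exactly once in the pending list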
+ assert_eq!(pending.len(), 1); + assert_eq!(pending[0], future.id()); + } + + #[test] + fn test_async_lifecycle() { + let ctx = ExecutionContext::new("llm".to_string(), 1, 0); + + let request = AsyncKernelRequest::linear(0x2000, 1024, 256); + let future = ctx.dispatch_async(request); + let id = future.id(); + + // Launch + ctx.mark_future_launched(id); + assert!(!ctx.has_pending_kernels()); + assert!(ctx.has_running_kernels()); + + // Complete + ctx.mark_future_completed(id, 0.05); + assert!(!ctx.has_running_kernels()); + assert!(future.is_ready()); + + let stats = ctx.async_stats(); + assert_eq!(stats.completed_count, 1); + } + + #[test] + fn test_per_context_session() { + let ctx = ExecutionContext::new("vision".to_string(), 2, 0); + + assert!(!ctx.is_session_active()); + + ctx.start_session(); + assert!(ctx.is_session_active()); + + ctx.end_session(); + assert!(!ctx.is_session_active()); + } + + #[test] + fn test_multiple_contexts_independent_sessions() { + let tts_ctx = ExecutionContext::new("tts".to_string(), 0, 0); + let llm_ctx = ExecutionContext::new("llm".to_string(), 1, 0); + let vision_ctx = ExecutionContext::new("vision".to_string(), 2, 0); + + // Start sessions independently + tts_ctx.start_session(); + assert!(tts_ctx.is_session_active()); + assert!(!llm_ctx.is_session_active()); + assert!(!vision_ctx.is_session_active()); + + llm_ctx.start_session(); + assert!(tts_ctx.is_session_active()); + assert!(llm_ctx.is_session_active()); + assert!(!vision_ctx.is_session_active()); + + // End TTS session, others continue + tts_ctx.end_session(); + assert!(!tts_ctx.is_session_active()); + assert!(llm_ctx.is_session_active()); + + // Dispatch async kernel on LLM while session is active + let request = AsyncKernelRequest::new(0x3000); + let future = llm_ctx.dispatch_async(request); + assert!(llm_ctx.has_pending_kernels()); + + llm_ctx.mark_future_launched(future.id()); + llm_ctx.mark_future_completed(future.id(), 0.1); + + assert!(future.is_ready()); + } + + #[test] + fn test_cancel_pending_future() { + let ctx = ExecutionContext::new("test".to_string(), 0, 0); + + let f1 = ctx.dispatch_async(AsyncKernelRequest::new(0x1000)); + let f2 = ctx.dispatch_async(AsyncKernelRequest::new(0x2000)); + + assert_eq!(ctx.get_pending_futures().len(), 2); + + // Cancel first + assert!(ctx.cancel_future(f1.id())); + assert_eq!(ctx.get_pending_futures().len(), 1); + + // Can't cancel already running + ctx.mark_future_launched(f2.id()); + assert!(!ctx.cancel_future(f2.id())); + } +} diff --git a/rust/pygpukit-core/src/scheduler/mod.rs b/rust/pygpukit-core/src/scheduler/mod.rs index 38e72c3..f25caeb 100644 --- a/rust/pygpukit-core/src/scheduler/mod.rs +++ b/rust/pygpukit-core/src/scheduler/mod.rs @@ -7,12 +7,16 @@ //! - Admission control //! - QoS policy framework //! - GPU resource partitioning +//! 
- Multi-LLM execution contexts mod task; mod core; mod admission; mod qos; mod partition; +mod async_exec; +mod execution_context; +mod dispatch_controller; pub use task::{TaskState, TaskPolicy, TaskMeta, TaskStats}; pub use core::{Scheduler, SchedulerStats}; @@ -28,3 +32,8 @@ pub use partition::{ PartitionManager, PartitionConfig, Partition, PartitionLimits, PartitionUsage, PartitionStats, PartitionError, }; +pub use async_exec::{ + FutureState, KernelFuture, KernelResult, AsyncKernelRequest, AsyncExecStats, AsyncExecutor, +}; +pub use execution_context::{ExecutionContext, ContextState, ContextStats}; +pub use dispatch_controller::{MultiLLMController, ControllerStats}; diff --git a/rust/pygpukit-python/src/lib.rs b/rust/pygpukit-python/src/lib.rs index 5ff3cc0..6df80e2 100644 --- a/rust/pygpukit-python/src/lib.rs +++ b/rust/pygpukit-python/src/lib.rs @@ -10,6 +10,7 @@ mod scheduler; mod transfer; mod dispatch; mod device; +mod llm; /// PyGPUkit Rust module #[pymodule] @@ -39,6 +40,11 @@ fn _pygpukit_rust(m: &Bound<'_, PyModule>) -> PyResult<()> { device::register(&device_module)?; m.add_submodule(&device_module)?; + // LLM submodule + let llm_module = PyModule::new(m.py(), "llm")?; + llm::register(&llm_module)?; + m.add_submodule(&llm_module)?; + // Also export at top level for convenience m.add_class::()?; m.add_class::()?; @@ -97,9 +103,25 @@ fn _pygpukit_rust(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + // Multi-LLM Controller (v0.2.6+) + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + // Async Execution (v0.2.6+) + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; // Device capabilities m.add_class::()?; m.add_class::()?; + // LLM support + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/rust/pygpukit-python/src/llm.rs b/rust/pygpukit-python/src/llm.rs new file mode 100644 index 0000000..b768407 --- /dev/null +++ b/rust/pygpukit-python/src/llm.rs @@ -0,0 +1,298 @@ +//! 
Python bindings for LLM support (safetensors loader, tokenizer) + +use pyo3::prelude::*; +use pyo3::exceptions::{PyIOError, PyKeyError, PyValueError}; +use pygpukit_core::llm::{SafeTensorsFile, Dtype, SafeTensorsError, Tokenizer, TokenizerError}; +use std::sync::Arc; + +/// Convert SafeTensorsError to PyErr +fn to_py_err(e: SafeTensorsError) -> PyErr { + match e { + SafeTensorsError::IoError(e) => PyIOError::new_err(e.to_string()), + SafeTensorsError::ParseError(e) => PyValueError::new_err(e), + SafeTensorsError::TensorNotFound(name) => PyKeyError::new_err(name), + SafeTensorsError::UnsupportedDtype(dtype) => PyValueError::new_err(dtype), + } +} + +/// Python wrapper for Dtype enum +#[pyclass(name = "Dtype", eq, eq_int)] +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum PyDtype { + Float32 = 0, + Float16 = 1, + BFloat16 = 2, + Float64 = 3, + Int32 = 4, + Int64 = 5, + Int16 = 6, + Int8 = 7, + UInt8 = 8, + Bool = 9, +} + +impl From for PyDtype { + fn from(dtype: Dtype) -> Self { + match dtype { + Dtype::Float32 => PyDtype::Float32, + Dtype::Float16 => PyDtype::Float16, + Dtype::BFloat16 => PyDtype::BFloat16, + Dtype::Float64 => PyDtype::Float64, + Dtype::Int32 => PyDtype::Int32, + Dtype::Int64 => PyDtype::Int64, + Dtype::Int16 => PyDtype::Int16, + Dtype::Int8 => PyDtype::Int8, + Dtype::UInt8 => PyDtype::UInt8, + Dtype::Bool => PyDtype::Bool, + } + } +} + +#[pymethods] +impl PyDtype { + /// Size in bytes of a single element + #[getter] + fn element_size(&self) -> usize { + match self { + PyDtype::Float64 | PyDtype::Int64 => 8, + PyDtype::Float32 | PyDtype::Int32 => 4, + PyDtype::Float16 | PyDtype::BFloat16 | PyDtype::Int16 => 2, + PyDtype::Int8 | PyDtype::UInt8 | PyDtype::Bool => 1, + } + } + + fn __repr__(&self) -> &'static str { + match self { + PyDtype::Float32 => "Dtype.Float32", + PyDtype::Float16 => "Dtype.Float16", + PyDtype::BFloat16 => "Dtype.BFloat16", + PyDtype::Float64 => "Dtype.Float64", + PyDtype::Int32 => "Dtype.Int32", + PyDtype::Int64 => "Dtype.Int64", + PyDtype::Int16 => "Dtype.Int16", + PyDtype::Int8 => "Dtype.Int8", + PyDtype::UInt8 => "Dtype.UInt8", + PyDtype::Bool => "Dtype.Bool", + } + } +} + +/// Metadata for a single tensor +#[pyclass(name = "TensorInfo")] +#[derive(Clone)] +pub struct PyTensorInfo { + /// Tensor name + #[pyo3(get)] + pub name: String, + /// Data type + #[pyo3(get)] + pub dtype: PyDtype, + /// Shape dimensions + #[pyo3(get)] + pub shape: Vec, + /// Byte offset within the data section + #[pyo3(get)] + pub offset: usize, + /// Total size in bytes + #[pyo3(get)] + pub size_bytes: usize, +} + +#[pymethods] +impl PyTensorInfo { + /// Total number of elements + #[getter] + fn numel(&self) -> usize { + self.shape.iter().product() + } + + fn __repr__(&self) -> String { + format!( + "TensorInfo(name='{}', dtype={:?}, shape={:?}, size_bytes={})", + self.name, self.dtype, self.shape, self.size_bytes + ) + } +} + +/// Memory-mapped SafeTensors file +#[pyclass(name = "SafeTensorsFile")] +pub struct PySafeTensorsFile { + inner: Arc, +} + +#[pymethods] +impl PySafeTensorsFile { + /// Open a safetensors file with memory mapping + #[new] + fn new(path: &str) -> PyResult { + let file = SafeTensorsFile::open(path).map_err(to_py_err)?; + Ok(PySafeTensorsFile { + inner: Arc::new(file), + }) + } + + /// Get list of all tensor names + #[getter] + fn tensor_names(&self) -> Vec { + self.inner.tensor_names().iter().map(|s| s.to_string()).collect() + } + + /// Get tensor info by name + fn tensor_info(&self, name: &str) -> PyResult { + let info = self.inner.tensor_info(name) + 
.ok_or_else(|| PyKeyError::new_err(name.to_string()))?; + Ok(PyTensorInfo { + name: info.name.clone(), + dtype: info.dtype.into(), + shape: info.shape.clone(), + offset: info.offset, + size_bytes: info.size_bytes, + }) + } + + /// Get tensor data as bytes + fn tensor_bytes(&self, name: &str) -> PyResult> { + let tensor = self.inner.tensor(name).map_err(to_py_err)?; + Ok(tensor.data.to_vec()) + } + + /// Get tensor as numpy array (only for Float32) + fn tensor_as_f32(&self, py: Python<'_>, name: &str) -> PyResult>> { + let tensor = self.inner.tensor(name).map_err(to_py_err)?; + let data = tensor.as_f32() + .ok_or_else(|| PyValueError::new_err("Tensor is not Float32"))?; + Ok(numpy::PyArray1::from_slice(py, data).into()) + } + + /// Total file size in bytes + #[getter] + fn file_size(&self) -> usize { + self.inner.file_size() + } + + /// Number of tensors in the file + #[getter] + fn num_tensors(&self) -> usize { + self.inner.num_tensors() + } + + fn __repr__(&self) -> String { + format!( + "SafeTensorsFile(num_tensors={}, file_size={})", + self.inner.num_tensors(), + self.inner.file_size() + ) + } + + fn __len__(&self) -> usize { + self.inner.num_tensors() + } + + fn __contains__(&self, name: &str) -> bool { + self.inner.tensor_info(name).is_some() + } +} + +/// Load a safetensors file +#[pyfunction] +fn load_safetensors(path: &str) -> PyResult { + PySafeTensorsFile::new(path) +} + +// ============================================================================ +// Tokenizer +// ============================================================================ + +/// Convert TokenizerError to PyErr +fn tokenizer_err_to_py(e: TokenizerError) -> PyErr { + match e { + TokenizerError::IoError(e) => PyIOError::new_err(e.to_string()), + TokenizerError::ParseError(e) => PyValueError::new_err(e), + TokenizerError::InvalidToken(t) => PyValueError::new_err(t), + } +} + +/// BPE Tokenizer for GPT-2 style models +#[pyclass(name = "Tokenizer")] +pub struct PyTokenizer { + inner: Tokenizer, +} + +#[pymethods] +impl PyTokenizer { + /// Load tokenizer from tokenizer.json file + #[new] + fn new(path: &str) -> PyResult { + let tokenizer = Tokenizer::from_file(path).map_err(tokenizer_err_to_py)?; + Ok(PyTokenizer { inner: tokenizer }) + } + + /// Load tokenizer from JSON string + #[staticmethod] + fn from_json(json: &str) -> PyResult { + let tokenizer = Tokenizer::from_json_str(json).map_err(tokenizer_err_to_py)?; + Ok(PyTokenizer { inner: tokenizer }) + } + + /// Get vocabulary size + #[getter] + fn vocab_size(&self) -> usize { + self.inner.vocab_size() + } + + /// Get BOS token ID if available + #[getter] + fn bos_token_id(&self) -> Option { + self.inner.bos_token_id() + } + + /// Get EOS token ID if available + #[getter] + fn eos_token_id(&self) -> Option { + self.inner.eos_token_id() + } + + /// Get PAD token ID if available + #[getter] + fn pad_token_id(&self) -> Option { + self.inner.pad_token_id() + } + + /// Encode text to token IDs + fn encode(&self, text: &str) -> Vec { + self.inner.encode(text) + } + + /// Decode token IDs to text + fn decode(&self, token_ids: Vec) -> String { + self.inner.decode(&token_ids) + } + + /// Get token string for an ID + fn id_to_token(&self, id: u32) -> Option { + self.inner.id_to_token(id).map(|s| s.to_string()) + } + + /// Get ID for a token string + fn token_to_id(&self, token: &str) -> Option { + self.inner.token_to_id(token) + } + + fn __repr__(&self) -> String { + format!("Tokenizer(vocab_size={})", self.inner.vocab_size()) + } + + fn __len__(&self) -> usize { + 
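+        // len(tokenizer) on the Python side reports the vocabulary size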
self.inner.vocab_size() + } +} + +/// Register the llm module +pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_function(wrap_pyfunction!(load_safetensors, m)?)?; + Ok(()) +} diff --git a/rust/pygpukit-python/src/scheduler.rs b/rust/pygpukit-python/src/scheduler.rs index 36db1b6..c0b7033 100644 --- a/rust/pygpukit-python/src/scheduler.rs +++ b/rust/pygpukit-python/src/scheduler.rs @@ -12,6 +12,8 @@ use pygpukit_core::scheduler::{ QosClass, QosPolicy, QosTaskMeta, QosEvaluation, QosPolicyEvaluator, QosStats, ResourceRequirements, PartitionManager, PartitionConfig, Partition, PartitionLimits, PartitionUsage, PartitionStats, + MultiLLMController, ContextState, ContextStats, ControllerStats, + FutureState, KernelFuture, KernelResult, AsyncKernelRequest, AsyncExecStats, }; /// Task state enum for Python @@ -1730,6 +1732,671 @@ impl PyPartitionManager { } } +// ============================================================================= +// Multi-LLM Controller Types +// ============================================================================= + +/// Context state enum for Python +#[pyclass(name = "ContextState", eq, eq_int)] +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum PyContextState { + Idle = 0, + Running = 1, + Paused = 2, +} + +impl From for PyContextState { + fn from(state: ContextState) -> Self { + match state { + ContextState::Idle => PyContextState::Idle, + ContextState::Running => PyContextState::Running, + ContextState::Paused => PyContextState::Paused, + } + } +} + +impl From for ContextState { + fn from(state: PyContextState) -> Self { + match state { + PyContextState::Idle => ContextState::Idle, + PyContextState::Running => ContextState::Running, + PyContextState::Paused => ContextState::Paused, + } + } +} + +/// Execution context statistics for Python +#[pyclass(name = "ContextStats")] +#[derive(Clone)] +pub struct PyContextStats { + inner: ContextStats, +} + +#[pymethods] +impl PyContextStats { + #[getter] + fn llm_id(&self) -> String { + self.inner.llm_id.clone() + } + + #[getter] + fn state(&self) -> PyContextState { + self.inner.state.into() + } + + #[getter] + fn stream_id(&self) -> u32 { + self.inner.stream_id + } + + #[getter] + fn max_vram(&self) -> usize { + self.inner.max_vram + } + + #[getter] + fn used_vram(&self) -> usize { + self.inner.used_vram + } + + #[getter] + fn available_vram(&self) -> usize { + self.inner.available_vram + } + + #[getter] + fn buffer_count(&self) -> usize { + self.inner.buffer_count + } + + fn __repr__(&self) -> String { + format!( + "ContextStats(llm_id='{}', state={:?}, stream={}, used_vram={})", + self.inner.llm_id, self.inner.state, self.inner.stream_id, self.inner.used_vram + ) + } +} + +/// Controller statistics for Python +#[pyclass(name = "ControllerStats")] +#[derive(Clone)] +pub struct PyControllerStats { + inner: ControllerStats, +} + +#[pymethods] +impl PyControllerStats { + #[getter] + fn initialized(&self) -> bool { + self.inner.initialized + } + + #[getter] + fn device_id(&self) -> i32 { + self.inner.device_id + } + + #[getter] + fn total_vram_budget(&self) -> usize { + self.inner.total_vram_budget + } + + #[getter] + fn device_total_memory(&self) -> usize { + self.inner.device_total_memory + } + + #[getter] + fn used_vram(&self) -> usize { + self.inner.used_vram + } + + #[getter] + fn available_vram(&self) -> usize { + self.inner.available_vram + } + + #[getter] + fn context_count(&self) -> usize { + 
self.inner.context_count + } + + #[getter] + fn stream_pool_size(&self) -> usize { + self.inner.stream_pool_size + } + + fn __repr__(&self) -> String { + format!( + "ControllerStats(initialized={}, contexts={}, used_vram={}, available_vram={})", + self.inner.initialized, self.inner.context_count, + self.inner.used_vram, self.inner.available_vram + ) + } +} + +/// Multi-LLM Dispatch Controller for Python +/// +/// Manages execution contexts for multiple LLM instances on a single GPU. +/// Uses stream-based isolation for concurrent execution. +/// +/// Example: +/// controller = MultiLLMController() +/// controller.initialize(0, 8 * GB, 8 * GB) +/// stream_id = controller.create_context("gpt2_a", 4 * GB) +/// controller.start_session() +/// # ... execute kernels ... +/// controller.end_session() +#[pyclass(name = "MultiLLMController")] +pub struct PyMultiLLMController { + inner: Arc, +} + +#[pymethods] +impl PyMultiLLMController { + /// Create a new controller (uninitialized) + #[new] + fn new() -> Self { + Self { + inner: Arc::new(MultiLLMController::new()), + } + } + + /// Initialize the controller + /// + /// Args: + /// device_id: CUDA device ID (default 0) + /// device_total_memory: Total device memory in bytes + /// total_vram_budget: VRAM budget for all contexts (0 = device total) + #[pyo3(signature = (device_id=0, device_total_memory=0, total_vram_budget=0))] + fn initialize(&self, device_id: i32, device_total_memory: usize, total_vram_budget: usize) { + // If device_total_memory is 0, use a sensible default (8GB) + let mem = if device_total_memory == 0 { 8 * 1024 * 1024 * 1024 } else { device_total_memory }; + self.inner.initialize(device_id, mem, total_vram_budget); + } + + /// Check if controller is initialized + fn is_initialized(&self) -> bool { + self.inner.is_initialized() + } + + /// Create an execution context for an LLM + /// + /// Args: + /// llm_id: Unique LLM identifier + /// max_vram: Maximum VRAM for this LLM (0 = share global budget) + /// + /// Returns: + /// The assigned stream ID for this context + #[pyo3(signature = (llm_id, max_vram=0))] + fn create_context(&self, llm_id: &str, max_vram: usize) -> PyResult { + self.inner.create_context(llm_id, max_vram) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e)) + } + + /// Get an existing context by LLM ID + fn get_context(&self, llm_id: &str) -> Option { + self.inner.get_context(llm_id).map(|s| PyContextStats { inner: s }) + } + + /// Destroy an execution context + fn destroy_context(&self, llm_id: &str) -> bool { + self.inner.destroy_context(llm_id) + } + + /// List all active context IDs + fn list_contexts(&self) -> Vec { + self.inner.list_contexts() + } + + /// Get number of active contexts + fn context_count(&self) -> usize { + self.inner.context_count() + } + + /// Get stream ID for a context + fn get_stream_id(&self, llm_id: &str) -> Option { + self.inner.get_stream_id(llm_id) + } + + /// Track a memory allocation for a context + fn track_allocation(&self, llm_id: &str, buffer_id: u64, size: usize) -> bool { + self.inner.track_allocation(llm_id, buffer_id, size) + } + + /// Track a memory deallocation for a context + fn track_deallocation(&self, llm_id: &str, buffer_id: u64) { + self.inner.track_deallocation(llm_id, buffer_id); + } + + /// Get total VRAM used across all contexts + fn used_vram(&self) -> usize { + self.inner.used_vram() + } + + /// Get available VRAM (global budget - used) + fn available_vram(&self) -> usize { + self.inner.available_vram() + } + + /// Start a session (mark all contexts as 
+ running)
+    fn start_session(&self) {
+        self.inner.start_session();
+    }
+
+    /// End a session (synchronize and mark all contexts as idle)
+    fn end_session(&self) {
+        self.inner.end_session();
+    }
+
+    /// Check if a session is active
+    fn is_session_active(&self) -> bool {
+        self.inner.is_session_active()
+    }
+
+    /// Get controller statistics
+    fn stats(&self) -> PyControllerStats {
+        PyControllerStats { inner: self.inner.stats() }
+    }
+
+    /// Reset the controller (destroy all contexts)
+    fn reset(&self) {
+        self.inner.reset();
+    }
+
+    // --- Async Execution ---
+
+    /// Dispatch an async kernel for a specific LLM context
+    ///
+    /// Args:
+    ///     llm_id: LLM identifier
+    ///     request: Kernel dispatch request
+    ///
+    /// Returns:
+    ///     KernelFuture for tracking execution
+    ///
+    /// Example:
+    ///     request = AsyncKernelRequest.linear(kernel_handle, 1024, 256)
+    ///     future = controller.dispatch_async("llm", request)
+    ///     # Do other work...
+    ///     result = future.wait()
+    fn dispatch_async(&self, llm_id: &str, request: PyAsyncKernelRequest) -> PyResult<PyKernelFuture> {
+        self.inner.dispatch_async(llm_id, request.inner)
+            .map(|f| PyKernelFuture { inner: f })
+            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e))
+    }
+
+    /// Get pending futures for a context
+    fn get_pending_futures(&self, llm_id: &str) -> Option<Vec<u64>> {
+        self.inner.get_pending_futures(llm_id)
+    }
+
+    /// Mark a future as launched
+    fn mark_future_launched(&self, llm_id: &str, future_id: u64) {
+        self.inner.mark_future_launched(llm_id, future_id);
+    }
+
+    /// Mark a future as completed
+    fn mark_future_completed(&self, llm_id: &str, future_id: u64, exec_time: f64) {
+        self.inner.mark_future_completed(llm_id, future_id, exec_time);
+    }
+
+    /// Mark a future as failed
+    fn mark_future_failed(&self, llm_id: &str, future_id: u64, error: String) {
+        self.inner.mark_future_failed(llm_id, future_id, error);
+    }
+
+    /// Cancel a pending future
+    fn cancel_future(&self, llm_id: &str, future_id: u64) -> bool {
+        self.inner.cancel_future(llm_id, future_id)
+    }
+
+    /// Get a future by ID from a context
+    fn get_future(&self, llm_id: &str, future_id: u64) -> Option<PyKernelFuture> {
+        self.inner.get_future(llm_id, future_id).map(|f| PyKernelFuture { inner: f })
+    }
+
+    /// Get async execution stats for a context
+    fn async_stats(&self, llm_id: &str) -> Option<PyAsyncExecStats> {
+        self.inner.async_stats(llm_id).map(|s| PyAsyncExecStats { inner: s })
+    }
+
+    // --- Per-Context Session Management ---
+
+    /// Start a session for a specific context
+    ///
+    /// Unlike global session(), per-context sessions allow independent
+    /// LLM execution. Each context can have its own session lifecycle.
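+    ///
+    /// The "tts" and "llm" ids in the example below are placeholders for contexts
+    /// created earlier via create_context().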
+ /// + /// Example: + /// # TTS and LLM run independently + /// controller.start_context_session("tts") + /// controller.start_context_session("llm") + /// + /// # Dispatch async work + /// tts_future = controller.dispatch_async("tts", tts_request) + /// llm_future = controller.dispatch_async("llm", llm_request) + /// + /// # Wait for results in any order + /// llm_result = llm_future.wait() # Get LLM first + /// tts_result = tts_future.wait() # Then TTS + fn start_context_session(&self, llm_id: &str) -> bool { + self.inner.start_context_session(llm_id) + } + + /// End a session for a specific context + fn end_context_session(&self, llm_id: &str) -> bool { + self.inner.end_context_session(llm_id) + } + + /// Check if a specific context has an active session + fn is_context_session_active(&self, llm_id: &str) -> Option { + self.inner.is_context_session_active(llm_id) + } + + fn __repr__(&self) -> String { + let stats = self.inner.stats(); + format!( + "MultiLLMController(initialized={}, contexts={}, used_vram={})", + stats.initialized, stats.context_count, stats.used_vram + ) + } +} + +// ============================================================================= +// Async Execution Types +// ============================================================================= + +/// Future state enum for Python +#[pyclass(name = "FutureState", eq, eq_int)] +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum PyFutureState { + Pending = 0, + Running = 1, + Completed = 2, + Failed = 3, + Cancelled = 4, +} + +impl From for PyFutureState { + fn from(state: FutureState) -> Self { + match state { + FutureState::Pending => PyFutureState::Pending, + FutureState::Running => PyFutureState::Running, + FutureState::Completed => PyFutureState::Completed, + FutureState::Failed => PyFutureState::Failed, + FutureState::Cancelled => PyFutureState::Cancelled, + } + } +} + +/// Kernel execution result for Python +#[pyclass(name = "KernelResult")] +#[derive(Clone)] +pub struct PyKernelResult { + inner: KernelResult, +} + +#[pymethods] +impl PyKernelResult { + /// Whether execution succeeded + #[getter] + fn success(&self) -> bool { + self.inner.success + } + + /// Error message if failed + #[getter] + fn error(&self) -> Option { + self.inner.error.clone() + } + + /// Execution time in seconds + #[getter] + fn exec_time(&self) -> f64 { + self.inner.exec_time + } + + /// Output data as bytes (if any) + #[getter] + fn output(&self) -> Option> { + self.inner.output.clone() + } + + fn __repr__(&self) -> String { + if self.inner.success { + format!("KernelResult(success=True, exec_time={:.4}s)", self.inner.exec_time) + } else { + format!("KernelResult(success=False, error='{}')", self.inner.error.as_deref().unwrap_or("unknown")) + } + } +} + +/// Async kernel request for Python +/// +/// Use this to specify kernel dispatch parameters. 
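+///
+/// Example (illustrative values; `kernel_handle` is a CUfunction address you already hold):
+///     # 1D launch over one million elements, 256 threads per block
+///     request = AsyncKernelRequest.linear(kernel_handle, 1_000_000, 256)
+///     # or configure an explicit 2D grid with 16x16 thread blocks
+///     request = AsyncKernelRequest(kernel_handle).with_grid(64, 64, 1).with_block(16, 16, 1)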
+#[pyclass(name = "AsyncKernelRequest")] +#[derive(Clone)] +pub struct PyAsyncKernelRequest { + inner: AsyncKernelRequest, +} + +#[pymethods] +impl PyAsyncKernelRequest { + /// Create a new async kernel request + /// + /// Args: + /// kernel_handle: Kernel function handle (CUfunction as int) + #[new] + fn new(kernel_handle: u64) -> Self { + Self { + inner: AsyncKernelRequest::new(kernel_handle), + } + } + + /// Create a linear kernel request (1D grid) + /// + /// Args: + /// kernel_handle: Kernel function handle + /// n_elements: Number of elements to process + /// block_size: Threads per block (default 256) + #[staticmethod] + #[pyo3(signature = (kernel_handle, n_elements, block_size=256))] + fn linear(kernel_handle: u64, n_elements: usize, block_size: u32) -> Self { + Self { + inner: AsyncKernelRequest::linear(kernel_handle, n_elements, block_size), + } + } + + /// Set grid dimensions + fn with_grid(&self, x: u32, y: u32, z: u32) -> Self { + Self { + inner: self.inner.clone().with_grid(x, y, z), + } + } + + /// Set block dimensions + fn with_block(&self, x: u32, y: u32, z: u32) -> Self { + Self { + inner: self.inner.clone().with_block(x, y, z), + } + } + + /// Set shared memory size + fn with_shared_mem(&self, bytes: u32) -> Self { + Self { + inner: self.inner.clone().with_shared_mem(bytes), + } + } + + /// Set kernel arguments (as list of u64 pointers) + fn with_args(&self, args: Vec) -> Self { + Self { + inner: self.inner.clone().with_args(args), + } + } + + #[getter] + fn kernel_handle(&self) -> u64 { + self.inner.kernel_handle + } + + #[getter] + fn grid(&self) -> (u32, u32, u32) { + self.inner.grid + } + + #[getter] + fn block(&self) -> (u32, u32, u32) { + self.inner.block + } + + #[getter] + fn shared_mem(&self) -> u32 { + self.inner.shared_mem + } + + fn __repr__(&self) -> String { + format!( + "AsyncKernelRequest(handle=0x{:x}, grid={:?}, block={:?})", + self.inner.kernel_handle, self.inner.grid, self.inner.block + ) + } +} + +/// Async execution statistics for Python +#[pyclass(name = "AsyncExecStats")] +#[derive(Clone)] +pub struct PyAsyncExecStats { + inner: AsyncExecStats, +} + +#[pymethods] +impl PyAsyncExecStats { + #[getter] + fn total_dispatched(&self) -> u64 { + self.inner.total_dispatched + } + + #[getter] + fn pending_count(&self) -> usize { + self.inner.pending_count + } + + #[getter] + fn running_count(&self) -> usize { + self.inner.running_count + } + + #[getter] + fn completed_count(&self) -> u64 { + self.inner.completed_count + } + + #[getter] + fn failed_count(&self) -> u64 { + self.inner.failed_count + } + + #[getter] + fn cancelled_count(&self) -> u64 { + self.inner.cancelled_count + } + + #[getter] + fn avg_exec_time(&self) -> f64 { + self.inner.avg_exec_time + } + + fn __repr__(&self) -> String { + format!( + "AsyncExecStats(dispatched={}, pending={}, running={}, completed={})", + self.inner.total_dispatched, self.inner.pending_count, + self.inner.running_count, self.inner.completed_count + ) + } +} + +/// Kernel future for Python +/// +/// Handle for tracking async kernel execution. Use `wait()` to block +/// until completion or `is_ready()` to check without blocking. +/// +/// Example: +/// request = AsyncKernelRequest(kernel_handle) +/// future = controller.dispatch_async("llm", request) +/// +/// # Do other work while kernel executes... 
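+/// #   (e.g. prepare the next prompt on the CPU or poll other futures)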
+/// +/// if future.is_ready(): +/// result = future.wait() +#[pyclass(name = "KernelFuture")] +#[derive(Clone)] +pub struct PyKernelFuture { + inner: KernelFuture, +} + +#[pymethods] +impl PyKernelFuture { + /// Get future ID + #[getter] + fn id(&self) -> u64 { + self.inner.id() + } + + /// Get stream ID where kernel is executing + #[getter] + fn stream_id(&self) -> u32 { + self.inner.stream_id() + } + + /// Get context ID (LLM ID) + #[getter] + fn context_id(&self) -> String { + self.inner.context_id().to_string() + } + + /// Get current state + #[getter] + fn state(&self) -> PyFutureState { + self.inner.state().into() + } + + /// Check if kernel execution is complete (non-blocking) + fn is_ready(&self) -> bool { + self.inner.is_ready() + } + + /// Wait for kernel completion (blocking) + /// + /// Returns the kernel result. If already complete, returns immediately. + /// If still running, blocks until completion. + fn wait(&self) -> PyKernelResult { + PyKernelResult { + inner: self.inner.wait(), + } + } + + /// Try to get result without blocking + /// + /// Returns None if not ready yet. + fn try_get(&self) -> Option { + self.inner.try_get().map(|r| PyKernelResult { inner: r }) + } + + /// Get execution time (if completed) + fn exec_time(&self) -> Option { + self.inner.exec_time() + } + + fn __repr__(&self) -> String { + format!( + "KernelFuture(id={}, context='{}', state={:?}, ready={})", + self.inner.id(), self.inner.context_id(), self.inner.state(), self.inner.is_ready() + ) + } +} + /// Register scheduler module pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; @@ -1760,5 +2427,16 @@ pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + // Multi-LLM Controller + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + // Async Execution + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/src/pygpukit/__init__.py b/src/pygpukit/__init__.py index f667a9c..2f84a8d 100644 --- a/src/pygpukit/__init__.py +++ b/src/pygpukit/__init__.py @@ -1,7 +1,9 @@ """PyGPUkit - A lightweight GPU runtime for Python.""" -__version__ = "0.2.5" +__version__ = "0.2.6" +# LLM support (safetensors loader) +from pygpukit import llm from pygpukit.core.array import GPUArray from pygpukit.core.device import ( DeviceInfo, @@ -27,7 +29,21 @@ jit, warmup, ) -from pygpukit.ops.basic import add, div, exp, log, matmul, max, mean, mul, relu, sub, sum +from pygpukit.ops.basic import ( + add, + div, + exp, + gelu, + layernorm, + log, + matmul, + max, + mean, + mul, + relu, + sub, + sum, +) # Try to import Rust types, fallback to Python implementations try: @@ -87,9 +103,13 @@ "exp", "log", "relu", + "gelu", + "layernorm", "matmul", # Reductions "sum", "mean", "max", + # LLM support + "llm", ] diff --git a/src/pygpukit/core/backend.py b/src/pygpukit/core/backend.py index 63e50d5..a22b271 100644 --- a/src/pygpukit/core/backend.py +++ b/src/pygpukit/core/backend.py @@ -480,3 +480,35 @@ def reset_backend() -> None: """Reset the backend to auto-detection.""" global _backend _backend = None + + +# Rust module (PyO3 bindings) +_rust_module: Any = None +_rust_import_attempted: bool = False + + +def get_rust_module() -> Any | None: + """Get the Rust module (PyO3 bindings) if available. + + Returns: + The _pygpukit_rust module if available, None otherwise. 
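+
+    Example:
+        >>> rust = get_rust_module()
+        >>> if rust is not None:
+        ...     llm_mod = rust.llm  # PyO3 "llm" submodule (SafeTensorsFile, Tokenizer, ...)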
+ """ + global _rust_module, _rust_import_attempted + + if _rust_import_attempted: + return _rust_module + + _rust_import_attempted = True + try: + from pygpukit import _pygpukit_rust # type: ignore[attr-defined] + + _rust_module = _pygpukit_rust + except ImportError: + _rust_module = None + + return _rust_module + + +def has_rust_module() -> bool: + """Check if the Rust module (PyO3 bindings) is available.""" + return get_rust_module() is not None diff --git a/src/pygpukit/llm/__init__.py b/src/pygpukit/llm/__init__.py new file mode 100644 index 0000000..92e603d --- /dev/null +++ b/src/pygpukit/llm/__init__.py @@ -0,0 +1,358 @@ +"""LLM support module for PyGPUkit. + +Provides: +- SafeTensors file loading with memory mapping +- Tensor metadata and data access +- GPU tensor allocation helpers +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..core.backend import get_rust_module + +if TYPE_CHECKING: + from collections.abc import Sequence + +# Get the Rust llm module +_rust = get_rust_module() +_llm = _rust.llm if _rust else None + + +class Dtype: + """Tensor data type enumeration.""" + + Float32 = 0 + Float16 = 1 + BFloat16 = 2 + Float64 = 3 + Int32 = 4 + Int64 = 5 + Int16 = 6 + Int8 = 7 + UInt8 = 8 + Bool = 9 + + _NAMES = { + 0: "float32", + 1: "float16", + 2: "bfloat16", + 3: "float64", + 4: "int32", + 5: "int64", + 6: "int16", + 7: "int8", + 8: "uint8", + 9: "bool", + } + + _SIZES = { + 0: 4, # float32 + 1: 2, # float16 + 2: 2, # bfloat16 + 3: 8, # float64 + 4: 4, # int32 + 5: 8, # int64 + 6: 2, # int16 + 7: 1, # int8 + 8: 1, # uint8 + 9: 1, # bool + } + + @classmethod + def element_size(cls, dtype: int) -> int: + """Get the size in bytes of a single element.""" + return cls._SIZES.get(dtype, 0) + + @classmethod + def name(cls, dtype: int) -> str: + """Get the string name of a dtype.""" + return cls._NAMES.get(dtype, "unknown") + + +class TensorInfo: + """Metadata for a single tensor in a safetensors file.""" + + def __init__( + self, + name: str, + dtype: int, + shape: Sequence[int], + offset: int, + size_bytes: int, + ): + self.name = name + self.dtype = dtype + self.shape = list(shape) + self.offset = offset + self.size_bytes = size_bytes + + @property + def numel(self) -> int: + """Total number of elements.""" + result = 1 + for dim in self.shape: + result *= dim + return result + + @property + def dtype_name(self) -> str: + """String name of the dtype.""" + return Dtype.name(self.dtype) + + def __repr__(self) -> str: + return ( + f"TensorInfo(name='{self.name}', dtype={self.dtype_name}, " + f"shape={self.shape}, size_bytes={self.size_bytes})" + ) + + +class SafeTensorsFile: + """Memory-mapped SafeTensors file. + + Provides efficient access to tensor metadata and data from a .safetensors file + using memory mapping for zero-copy data access. + + Example: + >>> st = SafeTensorsFile("model.safetensors") + >>> print(st.tensor_names) + ['weight', 'bias'] + >>> info = st.tensor_info('weight') + >>> print(info.shape, info.dtype_name) + [768, 768] float16 + >>> data = st.tensor_bytes('weight') + """ + + def __init__(self, path: str): + """Open a safetensors file. 
+ + Args: + path: Path to the .safetensors file + """ + if _llm is None: + raise RuntimeError("Rust LLM module not available") + self._inner = _llm.SafeTensorsFile(path) + + @property + def tensor_names(self) -> list[str]: + """Get list of all tensor names.""" + return self._inner.tensor_names + + @property + def file_size(self) -> int: + """Total file size in bytes.""" + return self._inner.file_size + + @property + def num_tensors(self) -> int: + """Number of tensors in the file.""" + return self._inner.num_tensors + + def tensor_info(self, name: str) -> TensorInfo: + """Get metadata for a tensor by name. + + Args: + name: Tensor name + + Returns: + TensorInfo with dtype, shape, offset, and size + + Raises: + KeyError: If tensor name not found + """ + info = self._inner.tensor_info(name) + return TensorInfo( + name=info.name, + dtype=int(info.dtype), + shape=info.shape, + offset=info.offset, + size_bytes=info.size_bytes, + ) + + def tensor_bytes(self, name: str) -> bytes: + """Get raw tensor data as bytes. + + Args: + name: Tensor name + + Returns: + Raw bytes of the tensor data + + Raises: + KeyError: If tensor name not found + """ + return bytes(self._inner.tensor_bytes(name)) + + def tensor_as_f32(self, name: str): + """Get tensor data as numpy float32 array. + + Args: + name: Tensor name + + Returns: + 1D numpy array of float32 values + + Raises: + KeyError: If tensor name not found + ValueError: If tensor dtype is not Float32 + """ + return self._inner.tensor_as_f32(name) + + def __len__(self) -> int: + return self.num_tensors + + def __contains__(self, name: str) -> bool: + return name in self._inner + + def __repr__(self) -> str: + return f"SafeTensorsFile(num_tensors={self.num_tensors}, file_size={self.file_size})" + + +def load_safetensors(path: str) -> SafeTensorsFile: + """Load a safetensors file. + + Args: + path: Path to the .safetensors file + + Returns: + SafeTensorsFile object for accessing tensor data + """ + return SafeTensorsFile(path) + + +class Tokenizer: + """BPE Tokenizer for GPT-2 style models. + + Loads tokenizer.json format and provides basic encode/decode functionality. + + Example: + >>> tok = Tokenizer("tokenizer.json") + >>> ids = tok.encode("Hello, world!") + >>> text = tok.decode(ids) + """ + + def __init__(self, path: str): + """Load tokenizer from tokenizer.json file. + + Args: + path: Path to the tokenizer.json file + """ + if _llm is None: + raise RuntimeError("Rust LLM module not available") + self._inner = _llm.Tokenizer(path) + + @classmethod + def from_json(cls, json_str: str) -> Tokenizer: + """Load tokenizer from JSON string. + + Args: + json_str: JSON string containing tokenizer config + + Returns: + Tokenizer instance + """ + if _llm is None: + raise RuntimeError("Rust LLM module not available") + instance = cls.__new__(cls) + instance._inner = _llm.Tokenizer.from_json(json_str) + return instance + + @property + def vocab_size(self) -> int: + """Get vocabulary size.""" + return self._inner.vocab_size + + @property + def bos_token_id(self) -> int | None: + """Get BOS (beginning of sequence) token ID if available.""" + return self._inner.bos_token_id + + @property + def eos_token_id(self) -> int | None: + """Get EOS (end of sequence) token ID if available.""" + return self._inner.eos_token_id + + @property + def pad_token_id(self) -> int | None: + """Get PAD token ID if available.""" + return self._inner.pad_token_id + + def encode(self, text: str) -> list[int]: + """Encode text to token IDs. 
+ + Args: + text: Input text to encode + + Returns: + List of token IDs + """ + return list(self._inner.encode(text)) + + def decode(self, token_ids: list[int]) -> str: + """Decode token IDs to text. + + Args: + token_ids: List of token IDs + + Returns: + Decoded text string + """ + return self._inner.decode(token_ids) + + def id_to_token(self, token_id: int) -> str | None: + """Get token string for an ID. + + Args: + token_id: Token ID + + Returns: + Token string if ID is valid, None otherwise + """ + return self._inner.id_to_token(token_id) + + def token_to_id(self, token: str) -> int | None: + """Get ID for a token string. + + Args: + token: Token string + + Returns: + Token ID if token exists, None otherwise + """ + return self._inner.token_to_id(token) + + def __len__(self) -> int: + return self.vocab_size + + def __repr__(self) -> str: + return f"Tokenizer(vocab_size={self.vocab_size})" + + +from pygpukit.llm.model import ( # noqa: E402 + MLP, + GPT2Config, + GPT2Model, + LayerNorm, + Linear, + TransformerBlock, + load_gpt2_from_safetensors, +) + +__all__ = [ + # SafeTensors + "Dtype", + "TensorInfo", + "SafeTensorsFile", + "load_safetensors", + # Tokenizer + "Tokenizer", + # Model components + "GPT2Config", + "GPT2Model", + "LayerNorm", + "Linear", + "MLP", + "TransformerBlock", + "load_gpt2_from_safetensors", +] diff --git a/src/pygpukit/llm/model.py b/src/pygpukit/llm/model.py new file mode 100644 index 0000000..77443b4 --- /dev/null +++ b/src/pygpukit/llm/model.py @@ -0,0 +1,436 @@ +"""LLM model components for PyGPUkit. + +Provides transformer building blocks for GPT-2 style models: +- MLP block (fc1 -> gelu -> fc2) +- TransformerBlock (ln -> mlp -> residual) +- GPT2Model (embedding -> blocks -> lm_head) +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from pygpukit.core.array import GPUArray +from pygpukit.core.factory import from_numpy +from pygpukit.ops.basic import add, gelu, layernorm, matmul + +if TYPE_CHECKING: + pass + + +@dataclass +class GPT2Config: + """Configuration for GPT-2 model. + + GPT-2 Small defaults: + vocab_size=50257, n_embd=768, n_layer=12, n_head=12 + """ + + vocab_size: int = 50257 + n_embd: int = 768 + n_layer: int = 12 + n_head: int = 12 + n_positions: int = 1024 + layer_norm_eps: float = 1e-5 + + @property + def n_inner(self) -> int: + """Inner dimension of MLP (4 * n_embd).""" + return 4 * self.n_embd + + +class Linear: + """Linear layer: y = xW^T + b + + For MVP, we store weight as [out_features, in_features] and transpose + during forward pass using simple element access. + """ + + def __init__( + self, + weight: GPUArray, + bias: GPUArray | None = None, + ): + """Initialize Linear layer. 
+ + Args: + weight: Weight matrix [out_features, in_features] + bias: Optional bias vector [out_features] + """ + if weight.ndim != 2: + raise ValueError(f"weight must be 2D, got {weight.ndim}D") + self.weight = weight # [out_features, in_features] + self.bias = bias + self.out_features = weight.shape[0] + self.in_features = weight.shape[1] + + def __call__(self, x: GPUArray) -> GPUArray: + """Forward pass: y = xW^T + b + + Args: + x: Input tensor [batch, in_features] + + Returns: + Output tensor [batch, out_features] + """ + if x.ndim != 2: + raise ValueError(f"input must be 2D [batch, in_features], got {x.ndim}D") + if x.shape[1] != self.in_features: + raise ValueError( + f"input features {x.shape[1]} doesn't match weight {self.in_features}" + ) + + # For MVP: transpose weight and use matmul + # x: [batch, in_features] + # weight: [out_features, in_features] + # We need: y = x @ weight.T = x @ [in_features, out_features] + + # Simple approach: transpose weight to CPU, create new GPU array + # This is not optimal but works for MVP + weight_np = self.weight.to_numpy() + weight_t = from_numpy(weight_np.T.copy()) # [in_features, out_features] + + # y = x @ weight_t: [batch, in_features] @ [in_features, out_features] = [batch, out_features] + y = matmul(x, weight_t) + + if self.bias is not None: + # Add bias: y[i, j] += bias[j] + # For now, do this on CPU as we don't have broadcast add + y_np = y.to_numpy() + bias_np = self.bias.to_numpy() + y_np += bias_np + y = from_numpy(y_np) + + return y + + +class MLP: + """MLP block for GPT-2. + + Structure: fc1 -> gelu -> fc2 + fc1: [n_embd] -> [n_inner] + fc2: [n_inner] -> [n_embd] + """ + + def __init__( + self, + c_fc_weight: GPUArray, + c_fc_bias: GPUArray | None, + c_proj_weight: GPUArray, + c_proj_bias: GPUArray | None, + ): + """Initialize MLP block. + + Args: + c_fc_weight: First linear weight [n_inner, n_embd] + c_fc_bias: First linear bias [n_inner] + c_proj_weight: Second linear weight [n_embd, n_inner] + c_proj_bias: Second linear bias [n_embd] + """ + self.c_fc = Linear(c_fc_weight, c_fc_bias) + self.c_proj = Linear(c_proj_weight, c_proj_bias) + + def __call__(self, x: GPUArray) -> GPUArray: + """Forward pass: fc1 -> gelu -> fc2 + + Args: + x: Input tensor [batch, n_embd] + + Returns: + Output tensor [batch, n_embd] + """ + h = self.c_fc(x) + h = gelu(h) + h = self.c_proj(h) + return h + + +class LayerNorm: + """Layer normalization with learnable parameters.""" + + def __init__( + self, + weight: GPUArray, + bias: GPUArray, + eps: float = 1e-5, + ): + """Initialize LayerNorm. + + Args: + weight: Scale parameter (gamma) [features] + bias: Shift parameter (beta) [features] + eps: Epsilon for numerical stability + """ + self.weight = weight + self.bias = bias + self.eps = eps + + def __call__(self, x: GPUArray) -> GPUArray: + """Forward pass. + + Args: + x: Input tensor [batch, features] + + Returns: + Normalized tensor [batch, features] + """ + return layernorm(x, self.weight, self.bias, self.eps) + + +class TransformerBlock: + """Transformer block (MLP only, no attention for MVP). + + Structure: ln -> mlp -> residual + """ + + def __init__( + self, + ln_weight: GPUArray, + ln_bias: GPUArray, + mlp: MLP, + eps: float = 1e-5, + ): + """Initialize TransformerBlock. 
+ + Args: + ln_weight: LayerNorm weight [n_embd] + ln_bias: LayerNorm bias [n_embd] + mlp: MLP block + eps: LayerNorm epsilon + """ + self.ln = LayerNorm(ln_weight, ln_bias, eps) + self.mlp = mlp + + def __call__(self, x: GPUArray) -> GPUArray: + """Forward pass: ln -> mlp -> residual + + Args: + x: Input tensor [batch, n_embd] + + Returns: + Output tensor [batch, n_embd] + """ + # LayerNorm + h = self.ln(x) + # MLP + h = self.mlp(h) + # Residual + return add(x, h) + + +class GPT2Model: + """GPT-2 model (MLP only, no attention for MVP). + + Structure: + - Token embedding + - Position embedding + - Transformer blocks (MLP only) + - Final LayerNorm + - LM head (tied to embedding) + """ + + def __init__( + self, + config: GPT2Config, + wte: GPUArray, # Token embedding [vocab_size, n_embd] + wpe: GPUArray, # Position embedding [n_positions, n_embd] + blocks: list[TransformerBlock], + ln_f_weight: GPUArray, + ln_f_bias: GPUArray, + ): + """Initialize GPT-2 model. + + Args: + config: Model configuration + wte: Token embedding weights [vocab_size, n_embd] + wpe: Position embedding weights [n_positions, n_embd] + blocks: List of transformer blocks + ln_f_weight: Final LayerNorm weight + ln_f_bias: Final LayerNorm bias + """ + self.config = config + self.wte = wte + self.wpe = wpe + self.blocks = blocks + self.ln_f = LayerNorm(ln_f_weight, ln_f_bias, config.layer_norm_eps) + + def __call__(self, input_ids: list[int], position_ids: list[int] | None = None) -> GPUArray: + """Forward pass. + + Args: + input_ids: Token IDs [seq_len] + position_ids: Optional position IDs [seq_len] + + Returns: + Hidden states [seq_len, n_embd] + """ + import numpy as np + + seq_len = len(input_ids) + + if position_ids is None: + position_ids = list(range(seq_len)) + + # Get embeddings by indexing (CPU for MVP) + wte_np = self.wte.to_numpy() + wpe_np = self.wpe.to_numpy() + + # Token embeddings: select rows from wte + token_embeds = wte_np[input_ids] # [seq_len, n_embd] + + # Position embeddings: select rows from wpe + pos_embeds = wpe_np[position_ids] # [seq_len, n_embd] + + # Combine embeddings + hidden = from_numpy((token_embeds + pos_embeds).astype(np.float32)) + + # Apply transformer blocks + for block in self.blocks: + hidden = block(hidden) + + # Final LayerNorm + hidden = self.ln_f(hidden) + + return hidden + + def lm_head(self, hidden: GPUArray) -> GPUArray: + """Compute logits from hidden states. + + Args: + hidden: Hidden states [seq_len, n_embd] + + Returns: + Logits [seq_len, vocab_size] + """ + # LM head is tied to embedding weights + # logits = hidden @ wte.T + wte_np = self.wte.to_numpy() + hidden_np = hidden.to_numpy() + logits = hidden_np @ wte_np.T + return from_numpy(logits.astype(hidden_np.dtype)) + + def generate( + self, + input_ids: list[int], + max_new_tokens: int = 20, + temperature: float = 1.0, + ) -> list[int]: + """Generate tokens autoregressively. 
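+
+        Decoding is greedy: the next token is the argmax over the last position's
+        logits. The temperature argument only rescales logits before the argmax,
+        so values other than 1.0 do not enable stochastic sampling.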
+ + Args: + input_ids: Initial token IDs + max_new_tokens: Maximum number of new tokens to generate + temperature: Sampling temperature (1.0 = greedy argmax) + + Returns: + List of all token IDs (input + generated) + """ + import numpy as np + + tokens = list(input_ids) + + for _ in range(max_new_tokens): + # Truncate to max context length + context = tokens[-self.config.n_positions:] + + # Forward pass + hidden = self(context) + + # Get logits for last position + logits = self.lm_head(hidden) + logits_np = logits.to_numpy() + last_logits = logits_np[-1] # [vocab_size] + + # Apply temperature + if temperature != 1.0: + last_logits = last_logits / temperature + + # Greedy decoding (argmax) + next_token = int(np.argmax(last_logits)) + tokens.append(next_token) + + # Stop at EOS (50256 for GPT-2) + if next_token == 50256: + break + + return tokens + + +def load_gpt2_from_safetensors( + model_path: str, + config: GPT2Config | None = None, +) -> GPT2Model: + """Load GPT-2 model from safetensors file. + + Note: This is an MVP that only loads MLP weights (no attention). + The model will not produce coherent text without attention. + + Args: + model_path: Path to model.safetensors file + config: Model configuration (defaults to GPT-2 small) + + Returns: + GPT2Model instance + """ + from pygpukit.llm import SafeTensorsFile + + if config is None: + config = GPT2Config() + + st = SafeTensorsFile(model_path) + + # Helper to load tensor + def load_tensor(name: str) -> GPUArray: + data = st.tensor_bytes(name) + info = st.tensor_info(name) + + import numpy as np + + # Determine numpy dtype + dtype_map = { + 0: np.float32, # Float32 + 1: np.float16, # Float16 + 2: np.float32, # BFloat16 -> convert to float32 for now + 3: np.float64, # Float64 + } + np_dtype = dtype_map.get(info.dtype, np.float32) + + # Create numpy array from bytes + arr = np.frombuffer(data, dtype=np_dtype).reshape(info.shape) + return from_numpy(arr.copy()) + + # Load embeddings + wte = load_tensor("wte.weight") + wpe = load_tensor("wpe.weight") + + # Load blocks + blocks = [] + for i in range(config.n_layer): + prefix = f"h.{i}." 
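+        # GPT-2 safetensors checkpoints name per-layer tensors
+        # "h.<layer>.<submodule>.<param>", e.g. "h.0.mlp.c_fc.weight"
+        # (assumed layout; it matches the lookups below)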
+ + # Check if MLP weights exist + mlp_c_fc_w_name = f"{prefix}mlp.c_fc.weight" + if mlp_c_fc_w_name not in st.tensor_names: + # Skip blocks without MLP (shouldn't happen for GPT-2) + continue + + # LayerNorm 2 (before MLP in GPT-2) + ln_2_w = load_tensor(f"{prefix}ln_2.weight") + ln_2_b = load_tensor(f"{prefix}ln_2.bias") + + # MLP + mlp_c_fc_w = load_tensor(f"{prefix}mlp.c_fc.weight") + mlp_c_fc_b = load_tensor(f"{prefix}mlp.c_fc.bias") + mlp_c_proj_w = load_tensor(f"{prefix}mlp.c_proj.weight") + mlp_c_proj_b = load_tensor(f"{prefix}mlp.c_proj.bias") + + mlp = MLP(mlp_c_fc_w, mlp_c_fc_b, mlp_c_proj_w, mlp_c_proj_b) + block = TransformerBlock(ln_2_w, ln_2_b, mlp, config.layer_norm_eps) + blocks.append(block) + + # Final LayerNorm + ln_f_w = load_tensor("ln_f.weight") + ln_f_b = load_tensor("ln_f.bias") + + return GPT2Model(config, wte, wpe, blocks, ln_f_w, ln_f_b) diff --git a/src/pygpukit/ops/__init__.py b/src/pygpukit/ops/__init__.py index c75125e..d58870e 100644 --- a/src/pygpukit/ops/__init__.py +++ b/src/pygpukit/ops/__init__.py @@ -1,5 +1,33 @@ """Operations module for PyGPUkit.""" -from pygpukit.ops.basic import add, div, exp, log, matmul, max, mean, mul, relu, sub, sum +from pygpukit.ops.basic import ( + add, + div, + exp, + gelu, + layernorm, + log, + matmul, + max, + mean, + mul, + relu, + sub, + sum, +) -__all__ = ["add", "sub", "mul", "div", "exp", "log", "relu", "matmul", "sum", "mean", "max"] +__all__ = [ + "add", + "sub", + "mul", + "div", + "exp", + "log", + "relu", + "gelu", + "layernorm", + "matmul", + "sum", + "mean", + "max", +] diff --git a/src/pygpukit/ops/basic.py b/src/pygpukit/ops/basic.py index 49272a7..75a8592 100644 --- a/src/pygpukit/ops/basic.py +++ b/src/pygpukit/ops/basic.py @@ -537,3 +537,138 @@ def _max_native(a: GPUArray) -> GPUArray: a_native = a._get_native() c_native = native.max(a_native) return GPUArray._wrap_native(c_native) + + +# ============================================================================ +# Neural Network Operations +# ============================================================================ + + +def gelu(a: GPUArray) -> GPUArray: + """GELU (Gaussian Error Linear Unit) activation. + + Computes: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + + Args: + a: Input array (float32, float64, float16, or bfloat16). + + Returns: + A new GPUArray containing gelu(a). + + Raises: + ValueError: If dtype is not a float type. + """ + _validate_float_dtype(a, "gelu") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _gelu_native(a) + else: + return _gelu_cpu(a) + + +def _gelu_cpu(a: GPUArray) -> GPUArray: + """CPU implementation of gelu.""" + a_np = a.to_numpy() + # GELU approximation + x = a_np.astype(np.float32) if a_np.dtype in [np.float16] else a_np + c1 = 0.7978845608 # sqrt(2/pi) + c2 = 0.044715 + result = x * 0.5 * (1 + np.tanh(c1 * (x + c2 * x**3))) + return from_numpy(result.astype(a_np.dtype)) + + +def _gelu_native(a: GPUArray) -> GPUArray: + """Native C++ CUDA implementation of gelu (zero-copy).""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + a_native = a._get_native() + c_native = native.gelu(a_native) + return GPUArray._wrap_native(c_native) + + +def layernorm( + input: GPUArray, + gamma: GPUArray, + beta: GPUArray, + eps: float = 1e-5, +) -> GPUArray: + """Layer normalization. + + Computes: (x - mean) / sqrt(var + eps) * gamma + beta + + Args: + input: Input array of shape [batch, features]. 
+ gamma: Scale parameter of shape [features]. + beta: Bias parameter of shape [features]. + eps: Small epsilon for numerical stability. + + Returns: + A new GPUArray containing the normalized output. + + Raises: + ValueError: If shapes or dtypes don't match. + """ + _validate_float_dtype(input, "layernorm") + + if input.ndim != 2: + raise ValueError(f"layernorm expects 2D input [batch, features], got {input.ndim}D") + if gamma.ndim != 1 or beta.ndim != 1: + raise ValueError("layernorm expects 1D gamma and beta") + if input.dtype != gamma.dtype or input.dtype != beta.dtype: + raise ValueError("layernorm: all inputs must have same dtype") + + features = input.shape[1] + if gamma.shape[0] != features or beta.shape[0] != features: + raise ValueError( + f"layernorm: gamma/beta size {gamma.shape[0]} must match features {features}" + ) + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _layernorm_native(input, gamma, beta, eps) + else: + return _layernorm_cpu(input, gamma, beta, eps) + + +def _layernorm_cpu( + input: GPUArray, + gamma: GPUArray, + beta: GPUArray, + eps: float, +) -> GPUArray: + """CPU implementation of layernorm.""" + x = input.to_numpy() + g = gamma.to_numpy() + b = beta.to_numpy() + + # Compute mean and variance along features axis + mean = x.mean(axis=1, keepdims=True) + var = x.var(axis=1, keepdims=True) + + # Normalize + normalized = (x - mean) / np.sqrt(var + eps) + + # Apply affine transform + result = normalized * g + b + return from_numpy(result) + + +def _layernorm_native( + input: GPUArray, + gamma: GPUArray, + beta: GPUArray, + eps: float, +) -> GPUArray: + """Native C++ CUDA implementation of layernorm (zero-copy).""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + input_native = input._get_native() + gamma_native = gamma._get_native() + beta_native = beta._get_native() + c_native = native.layernorm(input_native, gamma_native, beta_native, eps) + return GPUArray._wrap_native(c_native) diff --git a/src/pygpukit/scheduler/__init__.py b/src/pygpukit/scheduler/__init__.py index 86eefb3..0583272 100644 --- a/src/pygpukit/scheduler/__init__.py +++ b/src/pygpukit/scheduler/__init__.py @@ -4,6 +4,7 @@ - Memory reservation - Bandwidth pacing - QoS policies (Guaranteed, Burstable, BestEffort) +- Multi-LLM execution contexts (v0.2.6+) """ from pygpukit.scheduler.core import ( @@ -13,6 +14,31 @@ TaskState, ) +# Multi-LLM execution context API (v0.2.6+) +from pygpukit.scheduler.execution import ( + GB, + HAS_MULTI_LLM, + KB, + MB, + # Async execution (v0.2.6+) + AsyncKernelRequest, + ContextStats, + ExecutionContext, + KernelFuture, + KernelResult, + SchedulerStats, + context_session, + create_context, + destroy_context, + get_context, + initialize, + is_session_active, + list_contexts, + reset, + session, + stats, +) + # Rust scheduler (v0.2+) # Import Rust implementation if available try: @@ -47,4 +73,26 @@ "RustSchedulerStats", "RustTaskStats", "HAS_RUST_BACKEND", + # Multi-LLM execution context API (v0.2.6+) + "KB", + "MB", + "GB", + "initialize", + "create_context", + "get_context", + "destroy_context", + "list_contexts", + "session", + "context_session", + "is_session_active", + "stats", + "reset", + "ExecutionContext", + "ContextStats", + "SchedulerStats", + # Async execution (v0.2.6+) + "AsyncKernelRequest", + "KernelFuture", + "KernelResult", + "HAS_MULTI_LLM", ] diff --git a/src/pygpukit/scheduler/execution.py b/src/pygpukit/scheduler/execution.py new file mode 100644 index 
0000000..e8a391b --- /dev/null +++ b/src/pygpukit/scheduler/execution.py @@ -0,0 +1,718 @@ +"""Multi-LLM Execution Context Management. + +Provides execution context management for running multiple LLM instances +concurrently on a single GPU with stream-based isolation. + +Example: + >>> from pygpukit.scheduler import create_context, session + >>> + >>> # Create execution contexts for two LLMs + >>> ctx1 = create_context("gpt2_a", max_vram=4 * GB) + >>> ctx2 = create_context("gpt2_b", max_vram=4 * GB) + >>> + >>> # Run both LLMs in a session + >>> with session(): + ... llm1.generate("Hello") + ... llm2.generate("World") +""" + +from __future__ import annotations + +from collections.abc import Generator +from contextlib import contextmanager +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + +# Constants +KB = 1024 +MB = 1024 * KB +GB = 1024 * MB + +# Try to import Rust backend +_controller = None + +try: + import _pygpukit_rust._pygpukit_rust as _rust + + _MultiLLMController = _rust.MultiLLMController + _ContextState = _rust.ContextState + _ContextStats = _rust.ContextStats + _ControllerStats = _rust.ControllerStats + # Async types + _FutureState = _rust.FutureState + _KernelFuture = _rust.KernelFuture + _KernelResult = _rust.KernelResult + _AsyncKernelRequest = _rust.AsyncKernelRequest + _AsyncExecStats = _rust.AsyncExecStats + HAS_MULTI_LLM = True +except ImportError: + _MultiLLMController = None + _ContextState = None + _ContextStats = None + _ControllerStats = None + _FutureState = None + _KernelFuture = None + _KernelResult = None + _AsyncKernelRequest = None + _AsyncExecStats = None + HAS_MULTI_LLM = False + + +def _get_controller(): + """Get or create the global controller instance.""" + global _controller + if _controller is None: + if not HAS_MULTI_LLM: + raise RuntimeError( + "Multi-LLM scheduler requires Rust backend. " + "Please rebuild PyGPUkit with Rust support." + ) + _controller = _MultiLLMController() + return _controller + + +def initialize( + device_id: int = 0, + device_total_memory: int = 0, + total_vram_budget: int = 0, +) -> None: + """Initialize the multi-LLM scheduler. + + This must be called before creating execution contexts. + If not called explicitly, it will be called automatically + with default parameters on first context creation. + + Args: + device_id: CUDA device ID (default 0) + device_total_memory: Total device memory in bytes (0 = auto-detect) + total_vram_budget: Total VRAM budget for all contexts (0 = device total) + """ + controller = _get_controller() + if not controller.is_initialized(): + controller.initialize(device_id, device_total_memory, total_vram_budget) + + +def create_context( + llm_id: str, + max_vram: int = 0, + *, + device_id: int = 0, +) -> ExecutionContext: + """Create an execution context for an LLM instance. + + Each LLM must have exactly one execution context. The context + provides a dedicated CUDA stream for kernel isolation and + optional VRAM budget tracking. 
+ + Args: + llm_id: Unique identifier for the LLM instance + max_vram: Maximum VRAM budget for this LLM in bytes (0 = share global budget) + device_id: CUDA device ID (used for auto-initialization) + + Returns: + ExecutionContext for the LLM + + Raises: + RuntimeError: If context already exists for llm_id + + Example: + >>> ctx = create_context("gpt2_a", max_vram=4 * GB) + >>> print(ctx.stream_id) + 0 + """ + controller = _get_controller() + + # Auto-initialize if needed + if not controller.is_initialized(): + initialize(device_id=device_id) + + stream_id = controller.create_context(llm_id, max_vram) + return ExecutionContext(llm_id, stream_id, max_vram) + + +def get_context(llm_id: str) -> ExecutionContext | None: + """Get an existing execution context by LLM ID. + + Args: + llm_id: LLM identifier + + Returns: + ExecutionContext if found, None otherwise + """ + controller = _get_controller() + stats = controller.get_context(llm_id) + if stats is None: + return None + return ExecutionContext(stats.llm_id, stats.stream_id, stats.max_vram) + + +def destroy_context(llm_id: str) -> bool: + """Destroy an execution context. + + Args: + llm_id: LLM identifier + + Returns: + True if context was destroyed, False if not found + """ + controller = _get_controller() + return controller.destroy_context(llm_id) + + +def list_contexts() -> list[str]: + """List all active context LLM IDs. + + Returns: + List of LLM identifiers + """ + controller = _get_controller() + return controller.list_contexts() + + +@contextmanager +def session() -> Generator[None, None, None]: + """Context manager for a multi-LLM session. + + Within a session, all contexts are marked as running. + When the session ends, all contexts are synchronized + and marked as idle. + + Example: + >>> with session(): + ... llm1.generate("Hello") + ... llm2.generate("World") + ... # All streams synchronized here + """ + controller = _get_controller() + controller.start_session() + try: + yield + finally: + controller.end_session() + + +class context_session: + """Context manager for a per-context session. + + Supports both sync `with` and async `async with`. + + Unlike the global session(), this starts a session for a specific context, + allowing independent LLM execution. Each context can have its own + session lifecycle. + + Args: + ctx: The ExecutionContext to start a session for + + Example (sync): + >>> with context_session(tts_ctx), context_session(llm_ctx): + ... tts_future = tts_ctx.dispatch_async(tts_request) + ... llm_future = llm_ctx.dispatch_async(llm_request) + ... text = llm_future.wait() + ... audio = tts_future.wait() + + Example (async): + >>> async with context_session(llm_ctx), context_session(tts_ctx): + ... llm_f = llm_ctx.dispatch_async(llm_req) + ... tts_f = tts_ctx.dispatch_async(tts_req) + ... text, audio = await asyncio.gather(llm_f, tts_f) + """ + + def __init__(self, ctx: ExecutionContext): + self._ctx = ctx + + # Sync context manager + def __enter__(self) -> None: + self._ctx.start_session() + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self._ctx.end_session() + + # Async context manager + async def __aenter__(self) -> None: + self._ctx.start_session() + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self._ctx.end_session() + + +def is_session_active() -> bool: + """Check if a session is currently active. 
+ + Returns: + True if session is active + """ + controller = _get_controller() + return controller.is_session_active() + + +def stats() -> SchedulerStats: + """Get scheduler statistics. + + Returns: + SchedulerStats object with current state + """ + controller = _get_controller() + return SchedulerStats(controller.stats()) + + +def reset() -> None: + """Reset the scheduler, destroying all contexts.""" + controller = _get_controller() + controller.reset() + + +class AsyncKernelRequest: + """Request for async kernel dispatch. + + Use this to specify kernel dispatch parameters. + + Example: + >>> request = AsyncKernelRequest.linear(kernel_handle, 1024, 256) + >>> future = ctx.dispatch_async(request) + """ + + def __init__(self, kernel_handle: int): + """Create a new async kernel request. + + Args: + kernel_handle: Kernel function handle (CUfunction as int) + """ + if not HAS_MULTI_LLM: + raise RuntimeError("Multi-LLM scheduler requires Rust backend.") + self._inner = _AsyncKernelRequest(kernel_handle) + + @classmethod + def linear( + cls, kernel_handle: int, n_elements: int, block_size: int = 256 + ) -> AsyncKernelRequest: + """Create a linear kernel request (1D grid). + + Args: + kernel_handle: Kernel function handle + n_elements: Number of elements to process + block_size: Threads per block (default 256) + """ + if not HAS_MULTI_LLM: + raise RuntimeError("Multi-LLM scheduler requires Rust backend.") + obj = cls.__new__(cls) + obj._inner = _AsyncKernelRequest.linear(kernel_handle, n_elements, block_size) + return obj + + def with_grid(self, x: int, y: int = 1, z: int = 1) -> AsyncKernelRequest: + """Set grid dimensions.""" + new_obj = AsyncKernelRequest.__new__(AsyncKernelRequest) + new_obj._inner = self._inner.with_grid(x, y, z) + return new_obj + + def with_block(self, x: int, y: int = 1, z: int = 1) -> AsyncKernelRequest: + """Set block dimensions.""" + new_obj = AsyncKernelRequest.__new__(AsyncKernelRequest) + new_obj._inner = self._inner.with_block(x, y, z) + return new_obj + + def with_shared_mem(self, bytes: int) -> AsyncKernelRequest: + """Set shared memory size.""" + new_obj = AsyncKernelRequest.__new__(AsyncKernelRequest) + new_obj._inner = self._inner.with_shared_mem(bytes) + return new_obj + + def with_args(self, args: list[int]) -> AsyncKernelRequest: + """Set kernel arguments (as list of u64 pointers).""" + new_obj = AsyncKernelRequest.__new__(AsyncKernelRequest) + new_obj._inner = self._inner.with_args(args) + return new_obj + + @property + def kernel_handle(self) -> int: + return self._inner.kernel_handle + + @property + def grid(self) -> tuple[int, int, int]: + return self._inner.grid + + @property + def block(self) -> tuple[int, int, int]: + return self._inner.block + + def __repr__(self) -> str: + return f"AsyncKernelRequest(handle=0x{self.kernel_handle:x}, grid={self.grid}, block={self.block})" + + +class KernelFuture: + """Handle for tracking async kernel execution. + + Supports both synchronous `wait()` and Python asyncio `await`. + + Example (sync): + >>> future = ctx.dispatch_async(request) + >>> result = future.wait() # Blocking + + Example (async): + >>> async with context_session(ctx): + ... future = ctx.dispatch_async(request) + ... 
result = await future # Non-blocking in event loop + """ + + def __init__(self, inner: Any): + self._inner = inner + + @property + def id(self) -> int: + """Future ID.""" + return self._inner.id + + @property + def stream_id(self) -> int: + """Stream ID where kernel is executing.""" + return self._inner.stream_id + + @property + def context_id(self) -> str: + """Context ID (LLM ID).""" + return self._inner.context_id + + @property + def state(self) -> str: + """Current state: 'PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED'.""" + state_val = self._inner.state + if hasattr(state_val, "value"): + state_val = state_val.value + return ["PENDING", "RUNNING", "COMPLETED", "FAILED", "CANCELLED"][state_val] + + def is_ready(self) -> bool: + """Check if kernel execution is complete (non-blocking).""" + return self._inner.is_ready() + + def wait(self) -> KernelResult: + """Wait for kernel completion (blocking). + + Returns the kernel result. If already complete, returns immediately. + If still running, blocks until completion. + """ + return KernelResult(self._inner.wait()) + + def try_get(self) -> KernelResult | None: + """Try to get result without blocking. + + Returns None if not ready yet. + """ + result = self._inner.try_get() + return KernelResult(result) if result is not None else None + + def exec_time(self) -> float | None: + """Get execution time (if completed).""" + return self._inner.exec_time() + + def _wait_sync(self) -> KernelResult: + """Synchronous wait (for executor).""" + return KernelResult(self._inner.wait()) + + def __await__(self) -> Generator[Any, None, KernelResult]: + """Make KernelFuture awaitable for asyncio. + + Uses run_in_executor to avoid blocking the event loop. + The blocking `wait()` runs in the default ThreadPoolExecutor. + + Example: + >>> result = await future + """ + import asyncio + + loop = asyncio.get_running_loop() + return loop.run_in_executor(None, self._wait_sync).__await__() + + def __repr__(self) -> str: + return f"KernelFuture(id={self.id}, context='{self.context_id}', state={self.state})" + + +class KernelResult: + """Result of an async kernel execution.""" + + def __init__(self, inner: Any): + self._inner = inner + + @property + def success(self) -> bool: + """Whether execution succeeded.""" + return self._inner.success + + @property + def error(self) -> str | None: + """Error message if failed.""" + return self._inner.error + + @property + def exec_time(self) -> float: + """Execution time in seconds.""" + return self._inner.exec_time + + @property + def output(self) -> bytes | None: + """Output data as bytes (if any).""" + return self._inner.output + + def __repr__(self) -> str: + if self.success: + return f"KernelResult(success=True, exec_time={self.exec_time:.4f}s)" + return f"KernelResult(success=False, error='{self.error}')" + + +class ExecutionContext: + """Execution context for an LLM instance. + + Each LLM is bound to exactly one ExecutionContext, which provides: + - Dedicated CUDA stream for kernel isolation + - Optional VRAM budget tracking + - State management (IDLE, RUNNING, PAUSED) + - Async kernel dispatch with KernelFuture + + Do not instantiate directly; use `create_context()` instead. + + Example: + >>> ctx = create_context("gpt2", max_vram=4 * GB) + >>> + >>> # Async execution + >>> request = AsyncKernelRequest.linear(kernel_handle, 1024, 256) + >>> future = ctx.dispatch_async(request) + >>> # Do other work... 
+ >>> result = future.wait() + """ + + def __init__(self, llm_id: str, stream_id: int, max_vram: int): + self._llm_id = llm_id + self._stream_id = stream_id + self._max_vram = max_vram + + @property + def llm_id(self) -> str: + """LLM identifier.""" + return self._llm_id + + @property + def stream_id(self) -> int: + """Assigned CUDA stream ID.""" + return self._stream_id + + @property + def max_vram(self) -> int: + """Maximum VRAM budget in bytes (0 = unlimited).""" + return self._max_vram + + @property + def stats(self) -> ContextStats | None: + """Get current context statistics.""" + controller = _get_controller() + rust_stats = controller.get_context(self._llm_id) + if rust_stats is None: + return None + return ContextStats(rust_stats) + + def track_allocation(self, buffer_id: int, size: int) -> bool: + """Track a memory allocation. + + Args: + buffer_id: Unique buffer identifier + size: Size in bytes + + Returns: + True if allocation fits within budget + """ + controller = _get_controller() + return controller.track_allocation(self._llm_id, buffer_id, size) + + def track_deallocation(self, buffer_id: int) -> None: + """Track a memory deallocation. + + Args: + buffer_id: Buffer identifier + """ + controller = _get_controller() + controller.track_deallocation(self._llm_id, buffer_id) + + # --- Async Execution --- + + def dispatch_async(self, request: AsyncKernelRequest) -> KernelFuture: + """Dispatch an async kernel. + + Returns a KernelFuture that can be used to wait for completion. + The kernel is queued for execution on this context's stream. + + Args: + request: Kernel dispatch request + + Returns: + KernelFuture for tracking execution + + Example: + >>> request = AsyncKernelRequest.linear(kernel_handle, 1024, 256) + >>> future = ctx.dispatch_async(request) + >>> # Do other work while kernel executes... + >>> result = future.wait() + """ + controller = _get_controller() + rust_future = controller.dispatch_async(self._llm_id, request._inner) + return KernelFuture(rust_future) + + def start_session(self) -> None: + """Start a per-context session. + + Unlike the global session(), per-context sessions allow independent + LLM execution. Each context can have its own session lifecycle. + + Example: + >>> tts_ctx.start_session() + >>> llm_ctx.start_session() + >>> + >>> # Both run independently + >>> tts_future = tts_ctx.dispatch_async(tts_request) + >>> llm_future = llm_ctx.dispatch_async(llm_request) + >>> + >>> # Wait for results in any order + >>> llm_result = llm_future.wait() + >>> tts_result = tts_future.wait() + """ + controller = _get_controller() + controller.start_context_session(self._llm_id) + + def end_session(self) -> None: + """End the per-context session.""" + controller = _get_controller() + controller.end_context_session(self._llm_id) + + def is_session_active(self) -> bool: + """Check if a session is active for this context.""" + controller = _get_controller() + result = controller.is_context_session_active(self._llm_id) + return result if result is not None else False + + def destroy(self) -> bool: + """Destroy this context. 
+ + Returns: + True if context was destroyed + """ + return destroy_context(self._llm_id) + + def __repr__(self) -> str: + return f"ExecutionContext(llm_id='{self._llm_id}', stream_id={self._stream_id})" + + +class ContextStats: + """Statistics for an execution context.""" + + def __init__(self, rust_stats: Any): + self._inner = rust_stats + + @property + def llm_id(self) -> str: + return self._inner.llm_id + + @property + def state(self) -> str: + """Current state: 'IDLE', 'RUNNING', or 'PAUSED'.""" + state_val = self._inner.state + if hasattr(state_val, "value"): + state_val = state_val.value + return ["IDLE", "RUNNING", "PAUSED"][state_val] + + @property + def stream_id(self) -> int: + return self._inner.stream_id + + @property + def max_vram(self) -> int: + return self._inner.max_vram + + @property + def used_vram(self) -> int: + return self._inner.used_vram + + @property + def available_vram(self) -> int: + return self._inner.available_vram + + @property + def buffer_count(self) -> int: + return self._inner.buffer_count + + def __repr__(self) -> str: + return ( + f"ContextStats(llm_id='{self.llm_id}', state={self.state}, used_vram={self.used_vram})" + ) + + +class SchedulerStats: + """Statistics for the multi-LLM scheduler.""" + + def __init__(self, rust_stats: Any): + self._inner = rust_stats + + @property + def initialized(self) -> bool: + return self._inner.initialized + + @property + def device_id(self) -> int: + return self._inner.device_id + + @property + def total_vram_budget(self) -> int: + return self._inner.total_vram_budget + + @property + def device_total_memory(self) -> int: + return self._inner.device_total_memory + + @property + def used_vram(self) -> int: + return self._inner.used_vram + + @property + def available_vram(self) -> int: + return self._inner.available_vram + + @property + def context_count(self) -> int: + return self._inner.context_count + + @property + def stream_pool_size(self) -> int: + return self._inner.stream_pool_size + + def __repr__(self) -> str: + return ( + f"SchedulerStats(contexts={self.context_count}, " + f"used_vram={self.used_vram}, available_vram={self.available_vram})" + ) + + +# Export constants +__all__ = [ + # Constants + "KB", + "MB", + "GB", + # Functions + "initialize", + "create_context", + "get_context", + "destroy_context", + "list_contexts", + "session", + "context_session", + "is_session_active", + "stats", + "reset", + # Classes + "ExecutionContext", + "ContextStats", + "SchedulerStats", + # Async execution classes + "AsyncKernelRequest", + "KernelFuture", + "KernelResult", + # Feature flag + "HAS_MULTI_LLM", +] diff --git a/tests/test_3090ti_performance.py b/tests/test_3090ti_performance.py index 1812592..af879ef 100644 --- a/tests/test_3090ti_performance.py +++ b/tests/test_3090ti_performance.py @@ -200,7 +200,11 @@ class TestCorrectness: """Verify correctness is maintained with optimizations.""" def test_matmul_correctness_small(self, check_3090ti): - """Small matmul should be numerically correct.""" + """Small matmul should be numerically correct. + + Note: v0.2.6+ uses CUTLASS TF32 by default, which has a 10-bit mantissa + (vs FP32's 23-bit). Expected relative error is ~1e-3.
+ """ A = np.random.randn(256, 256).astype(np.float32) B = np.random.randn(256, 256).astype(np.float32) @@ -211,10 +215,14 @@ def test_matmul_correctness_small(self, check_3090ti): C_expected = A @ B rel_error = np.max(np.abs(C_result - C_expected)) / np.max(np.abs(C_expected)) - assert rel_error < 1e-5, f"Relative error too high: {rel_error}" + # TF32 threshold: 1e-3 (10-bit mantissa precision) + assert rel_error < 1e-3, f"Relative error too high: {rel_error}" def test_matmul_correctness_large(self, check_3090ti): - """Large matmul should be numerically correct.""" + """Large matmul should be numerically correct. + + Note: v0.2.6+ uses CUTLASS TF32 by default. + """ A = np.random.randn(4096, 4096).astype(np.float32) B = np.random.randn(4096, 4096).astype(np.float32) @@ -225,10 +233,14 @@ def test_matmul_correctness_large(self, check_3090ti): C_expected = A @ B rel_error = np.max(np.abs(C_result - C_expected)) / np.max(np.abs(C_expected)) - assert rel_error < 1e-4, f"Relative error too high: {rel_error}" + # TF32 threshold: 1e-3 + assert rel_error < 1e-3, f"Relative error too high: {rel_error}" def test_matmul_correctness_non_square(self, check_3090ti): - """Non-square matmul should be numerically correct.""" + """Non-square matmul should be numerically correct. + + Note: v0.2.6+ uses CUTLASS TF32 by default. + """ A = np.random.randn(2048, 1024).astype(np.float32) B = np.random.randn(1024, 4096).astype(np.float32) @@ -239,7 +251,8 @@ def test_matmul_correctness_non_square(self, check_3090ti): C_expected = A @ B rel_error = np.max(np.abs(C_result - C_expected)) / np.max(np.abs(C_expected)) - assert rel_error < 1e-4, f"Relative error too high: {rel_error}" + # TF32 threshold: 1e-3 + assert rel_error < 1e-3, f"Relative error too high: {rel_error}" class TestEfficiencyMetrics: diff --git a/third_party/cutlass b/third_party/cutlass new file mode 160000 index 0000000..d4e16f5 --- /dev/null +++ b/third_party/cutlass @@ -0,0 +1 @@ +Subproject commit d4e16f5d4e70cd95049e3708cbee01205abe43c0
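For reference, the CPU fallback added in `_layernorm_cpu` above reduces to a few lines of NumPy. A minimal standalone sketch of that math; the `eps` default and the sample shapes here are illustrative, not taken from the library:

```python
import numpy as np

def layernorm_reference(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Reference layernorm over the feature axis of a [batch, features] array."""
    mean = x.mean(axis=1, keepdims=True)           # per-row mean
    var = x.var(axis=1, keepdims=True)             # per-row variance
    normalized = (x - mean) / np.sqrt(var + eps)   # zero-mean, unit-variance rows
    return normalized * gamma + beta               # affine transform with 1D gamma/beta

x = np.random.randn(4, 16).astype(np.float32)
gamma = np.ones(16, dtype=np.float32)
beta = np.zeros(16, dtype=np.float32)
out = layernorm_reference(x, gamma, beta)
assert out.shape == (4, 16)
```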
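Because the execution-context API is re-exported from `pygpukit.scheduler` together with the `HAS_MULTI_LLM` flag, callers can probe for the Rust backend before touching it. A hedged sketch of that guard; the fallback branch is only illustrative:

```python
from pygpukit.scheduler import HAS_MULTI_LLM

if HAS_MULTI_LLM:
    from pygpukit.scheduler import GB, create_context, initialize

    initialize(device_id=0)
    ctx = create_context("llm", max_vram=4 * GB)
else:
    # Rust backend not built in: create_context() would raise RuntimeError,
    # so fall back to single-model execution here.
    ctx = None
```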
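Putting `initialize`, `create_context`, `context_session`, and `dispatch_async` together in plain blocking code, roughly as the docstrings above describe. The kernel handles below are placeholders; a real `CUfunction` address (and a built Rust backend) is required for this to run end to end:

```python
from pygpukit.scheduler import (
    GB,
    AsyncKernelRequest,
    context_session,
    create_context,
    initialize,
)

initialize(device_id=0)
llm_ctx = create_context("llm", max_vram=4 * GB)
tts_ctx = create_context("tts", max_vram=2 * GB)

# Placeholder handles: in practice these are CUfunction addresses from a loaded module.
llm_req = AsyncKernelRequest.linear(kernel_handle=0x1000, n_elements=1 << 20, block_size=256)
tts_req = AsyncKernelRequest.linear(kernel_handle=0x2000, n_elements=1 << 18)

with context_session(llm_ctx), context_session(tts_ctx):
    llm_future = llm_ctx.dispatch_async(llm_req)  # queued on llm_ctx's stream
    tts_future = tts_ctx.dispatch_async(tts_req)  # queued on tts_ctx's stream

    # Results can be collected in any order; wait() blocks until completion.
    tts_result = tts_future.wait()
    llm_result = llm_future.wait()

print(llm_result.success, llm_result.exec_time)
```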
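Since `KernelFuture.__await__` pushes the blocking `wait()` into the default thread-pool executor, futures from independent contexts can be awaited with `asyncio.gather`. A sketch under the same assumptions (placeholder kernel handles, Rust backend available):

```python
import asyncio

from pygpukit.scheduler import GB, AsyncKernelRequest, context_session, create_context, initialize

initialize(device_id=0)
llm_ctx = create_context("llm_async", max_vram=4 * GB)
tts_ctx = create_context("tts_async", max_vram=2 * GB)

async def run_both():
    async with context_session(llm_ctx), context_session(tts_ctx):
        # Placeholder kernel handles, as above.
        llm_f = llm_ctx.dispatch_async(AsyncKernelRequest.linear(0x1000, 1 << 20))
        tts_f = tts_ctx.dispatch_async(AsyncKernelRequest.linear(0x2000, 1 << 18))
        # Each await runs the blocking wait() in a thread-pool worker,
        # so the event loop stays responsive while both kernels execute.
        return await asyncio.gather(llm_f, tts_f)

llm_result, tts_result = asyncio.run(run_both())
```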
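VRAM budgeting is cooperative: the caller reports allocations against a context and reads the bookkeeping back through `ContextStats` and the scheduler-wide `stats()`. A sketch of that flow; the buffer ID and sizes are arbitrary:

```python
from pygpukit.scheduler import GB, MB, create_context, initialize, reset, stats

initialize(device_id=0)
ctx = create_context("budgeted_llm", max_vram=1 * GB)

fits = ctx.track_allocation(buffer_id=1, size=512 * MB)  # True while within the 1 GB budget
print(fits, ctx.stats)  # ContextStats(llm_id='budgeted_llm', state=..., used_vram=...)

ctx.track_deallocation(buffer_id=1)

s = stats()  # scheduler-wide view
print(s.context_count, s.used_vram, s.available_vram)

reset()  # destroy all contexts when done
```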
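The relaxed test thresholds all reduce to the same check: maximum relative error against a NumPy FP32 reference, bounded by 1e-3 under TF32 instead of the old 1e-5/1e-4. A small helper capturing that check; the function name is ours, not part of the test suite:

```python
import numpy as np

def assert_tf32_close(result: np.ndarray, expected: np.ndarray, rel_tol: float = 1e-3) -> None:
    """Assert max relative error within the TF32 tolerance used by the updated tests."""
    rel_error = np.max(np.abs(result - expected)) / np.max(np.abs(expected))
    assert rel_error < rel_tol, f"Relative error too high: {rel_error}"

A = np.random.randn(256, 256).astype(np.float32)
B = np.random.randn(256, 256).astype(np.float32)
# In the tests, the first argument is the GPU matmul output copied back to NumPy.
assert_tf32_close(A @ B, A @ B)
```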