diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1c24e3a..95fd3d4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -102,7 +102,7 @@ jobs: mkdir -p build && cd build cmake .. \ -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_CUDA_ARCHITECTURES="80;86;89;90;100;120" \ + -DCMAKE_CUDA_ARCHITECTURES="80;86;89;90;100;120a" \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \ -Dpybind11_DIR=$(python -c "import pybind11; print(pybind11.get_cmake_dir())") diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a44c0c0..7063d1e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -127,7 +127,7 @@ jobs: -DCMAKE_BUILD_TYPE=Release \ -DPYBIND11_FINDPYTHON=ON \ -Dpybind11_DIR=$(python -c "import pybind11; print(pybind11.get_cmake_dir())") \ - -DCMAKE_CUDA_ARCHITECTURES="80;86;89;90;100;120" \ + -DCMAKE_CUDA_ARCHITECTURES="80;86;89;90;100;120a" \ -DMODULE_SUFFIX="_cu131" cmake --build . --config Release -j$(nproc) @@ -216,7 +216,7 @@ jobs: env: # Skip native build since we have prebuilt modules PYGPUKIT_SKIP_NATIVE_BUILD: "1" - CMAKE_CUDA_ARCHITECTURES: "80;86;89;90;100;120" + CMAKE_CUDA_ARCHITECTURES: "80;86;89;90;100;120a" - name: Inject prebuilt native modules into wheel run: | @@ -419,7 +419,7 @@ jobs: -DCMAKE_BUILD_TYPE=Release ^ -DPYBIND11_FINDPYTHON=ON ^ -Dpybind11_DIR="%PYBIND11_DIR%" ^ - -DCMAKE_CUDA_ARCHITECTURES="80;86;89;90;100;120" ^ + -DCMAKE_CUDA_ARCHITECTURES="80;86;89;90;100;120a" ^ -DMODULE_SUFFIX="_cu131" cmake --build . --config Release @@ -537,7 +537,7 @@ jobs: set "PYGPUKIT_SKIP_NATIVE_BUILD=1" python -m build --wheel env: - CMAKE_CUDA_ARCHITECTURES: "80;86;89;90;100;120" + CMAKE_CUDA_ARCHITECTURES: "80;86;89;90;100;120a" - name: Inject prebuilt native modules into wheel shell: pwsh diff --git a/.gitmodules b/.gitmodules index 281cb2d..74bb94e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,4 @@ [submodule "third_party/cutlass"] path = third_party/cutlass - url = https://github.com/NVIDIA/cutlass.git + url = https://github.com/m96-chan/cutlass.git + branch = fix/sm120-alignment diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ff2e0f..b36519b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,24 @@ All notable changes to PyGPUkit will be documented in this file. 
+## [0.2.15] - 2025-12-26 + +### Added +- **FP8 I/O GEMM (SM120)**: Pure FP8 E4M3 input/output GEMM for FP8 model inference + - `matmul_fp8_fp8_sm120`: FP8 GEMM with unity scaling + - `matmul_fp8_fp8_blockwise_sm120`: FP8 GEMM with per-block scale factors + - `fp8_fp8_get_scale_sizes`: Get required scale factor sizes for (M, N, K) + - `fp8_fp8_sm120_available`: Check SM120 FP8 I/O availability +- **Pure NVF4 GEMM**: GPU-side BF16->NVF4 quantization with 3-stage pipeline (446 TFLOPS) +- **New math operations**: sin, cos, sqrt, rsqrt, abs, neg +- **New comparison operations**: clamp, where +- **New activation functions**: sigmoid, tanh +- **New reduction operations**: argmax, min, sum_axis +- **uint8/int8 NumPy support**: `from_numpy` now supports uint8 and int8 arrays + +### Changed +- Renamed `matmul_fp8_sm120.cu` to `matmul_fp8_fp32_sm120.cu` for clarity (FP8 compute, FP32 output) + ## [0.2.14] - 2025-12-23 ### Fixed diff --git a/CLAUDE.md b/CLAUDE.md index 2212dfd..330e7c6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -35,6 +35,19 @@ The core scheduling, memory management, GPU coordination, and performance-critic ``` PyGPUkit/ ├── src/pygpukit/ # Python API (NumPy-compatible) +│ ├── core/ # GPUArray, backend abstraction +│ ├── ops/ # GPU operations (matmul, nn, audio, etc.) +│ ├── llm/ # LLM inference (Qwen, LLaMA) +│ │ ├── models/ # Model implementations +│ │ └── sampling/ # Token sampling strategies +│ └── asr/ # Speech recognition (Whisper) +│ ├── preprocessing.py # Audio preprocessing (mel, normalize) +│ └── whisper/ # Whisper model implementation +│ ├── config.py # WhisperConfig +│ ├── loader.py # SafeTensors loader +│ ├── encoder.py # Whisper encoder +│ ├── decoder.py # Whisper decoder +│ └── model.py # WhisperModel high-level API ├── native/ │ ├── core/ # C++ (CUDA Runtime/Driver API) │ ├── jit/ # C++ (NVRTC) @@ -48,9 +61,20 @@ PyGPUkit/ │ │ └── device.rs # DeviceCapabilities, KernelType │ └── pygpukit-python/ # PyO3 bindings ├── examples/ +├── benchmarks/ # Performance benchmarks └── tests/ ``` +### Module Separation Policy + +| Module | Purpose | Input | Output | +|--------|---------|-------|--------| +| `llm/` | Text generation | Text tokens | Text tokens | +| `asr/` | Speech recognition | Audio waveform | Text | +| `ops/` | Low-level GPU ops | GPUArray | GPUArray | + +**Rationale**: Modules are separated by **modality** (audio vs text), not by architecture (transformer). This follows industry conventions (HuggingFace, OpenAI API) and enables clean future expansion (TTS, vision, etc.). 
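+Illustrative sketch of the policy (the `asr` calls mirror `examples/whisper_realtime_stt.py`; the `llm` import path is an assumption, and the silent-audio input is a placeholder):
+
+```python
+import numpy as np
+import pygpukit as gpk
+
+# ops/: modality-agnostic, GPUArray in -> GPUArray out
+x = gpk.from_numpy(np.array([-1.0, 2.0], dtype=np.float32))
+y = gpk.relu(x)
+
+# asr/: audio waveform in -> text out
+from pygpukit.asr import WhisperModel
+stt = WhisperModel.from_pretrained("kotoba-tech/kotoba-whisper-v2.0")
+text = stt.transcribe(np.zeros(16000, dtype=np.float32), language="ja").text
+
+# llm/: text tokens in -> text tokens out (module path assumed)
+# from pygpukit.llm import load_model_from_safetensors
+```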
+ ### Language Responsibilities | Component | Language | Reason | @@ -530,7 +554,7 @@ Edit → Build → Validate → Benchmark → Commit cd /d/Projects/m96-chan/PyGPUkit ./build.sh 86 # SM 86のみ (RTX 3090 Ti) ./build.sh 120 # SM 120のみ (RTX 5090) -./build.sh # デフォルト: SM 86 +./build.sh # デフォルト: SM 120a ``` **Windows cmd.exeからビルド(代替):** @@ -939,11 +963,18 @@ accepted_tokens = model.jacobi_decode_step(draft_tokens, position) cd /d/Projects/m96-chan/PyGPUkit ./build.sh 86 # SM 86のみ (RTX 3090 Ti) ./build.sh 120 # SM 120のみ (RTX 5090) -./build.sh # デフォルト: SM 86 +./build.sh # デフォルト: SM 120a ``` **サポートSM:** 80, 86, 89, 90, 100, 120 +### Local Development Hardware + +| Machine | GPU | SM | CUDA Toolkit | Notes | +|---------|-----|-----|--------------|-------| +| Primary | RTX 5090 | 120 | 13.1 | Blackwell GeForce, FP8 testing | +| Secondary | RTX 3090 Ti | 86 | 12.x | Ampere, TF32 benchmarks | + ### Tokenizer **PyGPUkit内蔵のTokenizerは使用しない。HuggingFace `tokenizers`ライブラリを使用する。** diff --git a/README.md b/README.md index 07b2ad3..47462c6 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,90 @@ PyGPUkit aims to be the "micro-runtime for GPU computing": small, fast, and idea --- +## What's New in v0.2.15 + +### FP8 I/O GEMM (SM120) +Pure FP8 input/output GEMM for FP8 model inference (Llama 3.1 FP8, Qwen FP8, etc.): + +| Function | Description | +|----------|-------------| +| `matmul_fp8_fp8_sm120` | FP8 E4M3 input -> FP8 E4M3 output (unity scaling) | +| `matmul_fp8_fp8_blockwise_sm120` | FP8 with block-wise scale_A / scale_B | +| `fp8_fp8_get_scale_sizes` | Get required scale factor sizes for (M, N, K) | +| `fp8_fp8_sm120_available` | Check SM120 FP8 I/O availability | + +```python +import pygpukit as gpk +import numpy as np + +# Check availability +if gpk.fp8_fp8_sm120_available(): + # Get scale sizes for blockwise scaling + sfa_size, sfb_size = gpk.fp8_fp8_get_scale_sizes(M, N, K) + + # Blockwise scaled FP8 GEMM (for real FP8 models) + scale_a = gpk.from_numpy(np.ones(sfa_size, dtype=np.float32)) + scale_b = gpk.from_numpy(np.ones(sfb_size, dtype=np.float32)) + C = gpk.matmul_fp8_fp8_blockwise_sm120(A_fp8, B_fp8, scale_a, scale_b) +``` + +### Pure NVF4 GEMM (446 TFLOPS) +GPU-side BF16->NVF4 quantization with 3-stage pipeline for maximum throughput: + +| Matrix Size | TFLOPS | Notes | +|-------------|--------|-------| +| 8192x8192 | 320 | Branchless vectorized loads | +| 12288x12288 | 400 | 3-stage async pipeline | +| 16384x16384 | **446** | Direct write to user buffer | + +### New Math Operations +Extended math operations for GPU computing: + +| Category | Operations | +|----------|------------| +| **Trigonometric** | `sin`, `cos` | +| **Power/Root** | `sqrt`, `rsqrt` | +| **Sign** | `abs`, `neg` | +| **Comparison** | `clamp`, `where` | +| **Activation** | `sigmoid`, `tanh` | +| **Reduction** | `argmax`, `min`, `sum_axis` | + +```python +import pygpukit as gpk + +# Trigonometric +y = gpk.sin(x) +y = gpk.cos(x) + +# Power operations +y = gpk.sqrt(x) +y = gpk.rsqrt(x) # 1/sqrt(x) + +# Element-wise comparison +y = gpk.clamp(x, min_val=-1.0, max_val=1.0) +y = gpk.where(cond, x, y) # cond ? 
x : y + +# New activations +y = gpk.sigmoid(x) +y = gpk.tanh(x) + +# New reductions +idx = gpk.argmax(x) # Index of maximum +val = gpk.min(x) # Minimum value +y = gpk.sum_axis(x, 1) # Sum along axis +``` + +### uint8/int8 NumPy Support +`from_numpy` now supports uint8 and int8 arrays for FP8 data handling: + +```python +# FP8 data stored as uint8 +fp8_data = np.array([...], dtype=np.uint8) +gpu_fp8 = gpk.from_numpy(fp8_data) +``` + +--- + ## What's New in v0.2.14 ### Packaging Fixes @@ -43,10 +127,10 @@ v0.2.13 and v0.2.14 fix wheel RECORD file issues that caused PyPI deprecation wa | v0.2.14 | Windows wheel missing `licenses/LICENSE` in RECORD | Added `-Recurse` to scan dist-info subdirectories | | v0.2.13 | Hardcoded version in release workflow | Dynamic dist-info folder detection | -**Recommended:** Use v0.2.14 or later. +**Recommended:** Use v0.2.15 or later. ```bash -pip install pygpukit>=0.2.14 +pip install pygpukit>=0.2.15 ``` --- @@ -530,6 +614,37 @@ print(f"NVRTC Path: {gp.get_nvrtc_path()}") # Path to NVRTC DLL (if available) > **Note:** CUTLASS is automatic for compatible sizes (16-aligned). Use `PYGPUKIT_NO_TF32=1` for full FP32 precision. +### GEMV Performance (RTX 5090, SM120a) + +For LLM decode (M=1), custom GEMV kernels significantly outperform cuBLASLt: + +| Model Layer | K | N | cuBLASLt | BF16 GEMV | NVF4 GEMV | Memory | +|-------------|------|-------|----------|-----------|-----------|--------| +| Qwen-7B hidden | 4096 | 4096 | 413us | **97us** | 152us | 73% less | +| Qwen-7B MLP | 4096 | 11008 | 418us | **96us** | 153us | 73% less | +| Qwen-72B hidden | 8192 | 8192 | 799us | 266us | **265us** | 73% less | +| Qwen-72B MLP | 8192 | 29568 | 1603us | **375us** | 454us | 73% less | + +| Kernel | Description | Use Case | +|--------|-------------|----------| +| **BF16 GEMV** | Custom BF16 kernel optimized for M=1 | Speed priority | +| **NVF4 GEMV** | 4-bit NVF4 weights with block scaling | Memory priority (73% reduction) | + +> **Note:** For large K (8192+), NVF4 matches BF16 speed while using 73% less memory. Ideal for memory-constrained LLM inference. + +### NVF4-BF16 GEMM Performance (RTX 5090, SM120a) + +4-bit NVF4 GEMM with BF16 I/O using CUTLASS block-scaled tensor operations: + +| Matrix Size | TFLOPS (median) | TFLOPS (max) | Time (ms) | +|-------------|-----------------|--------------|-----------| +| 4096×4096 | 53 | 55 | 2.6 | +| 8192×8192 | 141 | 143 | 7.8 | +| 12288×12288 | 201 | 216 | 18.5 | +| 16384×16384 | **246** | **252** | 35.8 | + +> **Note:** GPU-side BF16→NVF4 quantization with unit scaling. No host-device copies. Ideal for memory-bound LLM inference with 4x bandwidth reduction vs BF16. 
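+A minimal usage sketch, mirroring `benchmarks/benchmark_nvf4_bf16.py` (BF16 operands are passed as uint16 bit patterns; the 4096 sizes are placeholders):
+
+```python
+import numpy as np
+from pygpukit.core.factory import from_numpy
+from pygpukit.ops import matmul_nvf4_bf16_sm120, nvf4_bf16_sm120_available
+
+if nvf4_bf16_sm120_available():
+    M = N = K = 4096
+    # BF16 stored as uint16: keep the top 16 bits of the float32 bit pattern
+    A = (np.random.rand(M, K).astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)
+    B = (np.random.rand(K, N).astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)
+    # Quantizes BF16 -> NVF4 on the GPU and returns a BF16 result (uint16 view)
+    C = matmul_nvf4_bf16_sm120(from_numpy(A), from_numpy(B))
+    print(C.to_numpy().dtype)  # uint16
+```
+
+For M=1 decode workloads, prefer the GEMV kernels listed in the table above.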
+ --- ## Installation @@ -695,6 +810,7 @@ PyGPUkit/ | **v0.2.10** | **Dynamic cuBLASLt loading**, CUDA Graph optimizations, descriptor caching | | **v0.2.11** | **Batch decode** (6.8x speedup), Decode Strategy framework, Driver API async, Dual CUDA builds, RTX 5090 (SM120) | | **v0.2.12** | **Advanced audio processing** (ISTFT, Griffin-Lim, HPSS, CQT, pitch detection, time stretch) | +| **v0.2.15** | **FP8 I/O GEMM** (blockwise scaling), Pure NVF4 (446 TFLOPS), New math ops (sin, cos, sqrt, rsqrt, abs, neg, clamp, where, sigmoid, tanh, argmax, min, sum_axis) | ### Planned diff --git a/benchmarks/benchmark_nvf4_bf16.py b/benchmarks/benchmark_nvf4_bf16.py new file mode 100644 index 0000000..2a5213b --- /dev/null +++ b/benchmarks/benchmark_nvf4_bf16.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +""" +NVF4-BF16 GEMM Benchmark for SM120 (Blackwell GeForce) + +Benchmarks NVF4 (4-bit) GEMM with BF16 I/O. +NVF4 provides 2x memory bandwidth compared to FP8. +""" + +import time + +import numpy as np + + +def bf16_to_f32(bf16_uint16: np.ndarray) -> np.ndarray: + """Convert BFloat16 (stored as uint16) to float32.""" + bf16_uint16 = bf16_uint16.astype(np.uint16) + f32_bits = bf16_uint16.astype(np.uint32) << 16 + return f32_bits.view(np.float32) + + +def f32_to_bf16(f32: np.ndarray) -> np.ndarray: + """Convert float32 to BFloat16 (stored as uint16).""" + f32 = f32.astype(np.float32) + f32_bits = f32.view(np.uint32) + bf16_bits = (f32_bits >> 16).astype(np.uint16) + return bf16_bits + + +def benchmark_nvf4_bf16(sizes: list[int], warmup: int = 5, iterations: int = 20): + """Benchmark NVF4-BF16 GEMM at various sizes.""" + from pygpukit.core.backend import get_native_module + from pygpukit.core.factory import from_numpy + from pygpukit.ops import matmul_nvf4_bf16_sm120, nvf4_bf16_sm120_available + + native = get_native_module() + + if not nvf4_bf16_sm120_available(): + print("NVF4-BF16 SM120 not available") + return + + print("=" * 70) + print("NVF4-BF16 GEMM Benchmark (SM120 Blackwell GeForce)") + print("=" * 70) + + # Get GPU info + props = native.get_device_properties(0) + print(f"GPU: {props.name}") + print(f"SM: {props.compute_capability_major}.{props.compute_capability_minor}") + print() + print("GPU-side quantization: BF16 -> NVF4 (no H2D copies)") + print() + + results = [] + + for size in sizes: + M, N, K = size, size, size + flops = 2.0 * M * N * K # FLOPs for GEMM + + # Create NVF4-appropriate data (values in representable range) + nvf4_values = np.array([0.5, 1.0, 1.5, 2.0, 3.0, 4.0], dtype=np.float32) + A = np.random.choice(nvf4_values, size=(M, K)).astype(np.float32) + B = np.random.choice(nvf4_values, size=(K, N)).astype(np.float32) + + A_bf16 = f32_to_bf16(A) + B_bf16 = f32_to_bf16(B) + + A_gpu = from_numpy(A_bf16) + B_gpu = from_numpy(B_bf16) + + # Warmup + for _ in range(warmup): + C_gpu = matmul_nvf4_bf16_sm120(A_gpu, B_gpu) + native.device_synchronize() + + # Benchmark + times = [] + for _ in range(iterations): + native.device_synchronize() + start = time.perf_counter() + C_gpu = matmul_nvf4_bf16_sm120(A_gpu, B_gpu) + native.device_synchronize() + end = time.perf_counter() + times.append(end - start) + + # Get result and verify + C_uint16 = C_gpu.to_numpy() + C_f32 = bf16_to_f32(C_uint16) + C_ref = bf16_to_f32(A_bf16) @ bf16_to_f32(B_bf16) + + rel_error = np.linalg.norm(C_f32 - C_ref) / np.linalg.norm(C_ref) + + median_time = np.median(times) + min_time = np.min(times) + tflops_median = flops / median_time / 1e12 + tflops_max = flops / min_time / 1e12 + + results.append( + { + "size": 
size, + "tflops_median": tflops_median, + "tflops_max": tflops_max, + "time_ms": median_time * 1000, + "rel_error": rel_error, + } + ) + + status = "PASS" if rel_error < 0.05 else "FAIL" + print( + f"{M}x{N}x{K}: {tflops_median:.2f} TFLOPS (median), " + f"{tflops_max:.2f} TFLOPS (max), " + f"rel_error={rel_error:.2e} [{status}]" + ) + + print() + print("=" * 70) + print("Summary Table (for README)") + print("=" * 70) + print("| Size | TFLOPS (median) | TFLOPS (max) | Time (ms) |") + print("|------|-----------------|--------------|-----------|") + for r in results: + print( + f"| {r['size']}x{r['size']} | {r['tflops_median']:.2f} | " + f"{r['tflops_max']:.2f} | {r['time_ms']:.2f} |" + ) + + return results + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="NVF4-BF16 GEMM Benchmark") + parser.add_argument( + "--sizes", + nargs="+", + type=int, + default=[1024, 2048, 4096, 8192], + help="Matrix sizes to benchmark", + ) + parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations") + parser.add_argument("--iterations", type=int, default=20, help="Number of benchmark iterations") + + args = parser.parse_args() + + benchmark_nvf4_bf16(args.sizes, args.warmup, args.iterations) diff --git a/benchmarks/benchmark_nvf4_nvf4.py b/benchmarks/benchmark_nvf4_nvf4.py new file mode 100644 index 0000000..6ff909d --- /dev/null +++ b/benchmarks/benchmark_nvf4_nvf4.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +""" +Pure NVF4 GEMM Benchmark for SM120 (Blackwell GeForce) + +Benchmarks NVF4 GEMM without quantization overhead to measure +pure tensor core performance. +""" + +import time + +import numpy as np + + +def benchmark_nvf4_nvf4(sizes: list[int], warmup: int = 5, iterations: int = 20): + """Benchmark pure NVF4 GEMM at various sizes.""" + from pygpukit.core.backend import get_native_module + from pygpukit.core.factory import zeros + + native = get_native_module() + + if not native.nvf4_nvf4_sm120_available(): + print("NVF4-NVF4 SM120 not available") + return + + print("=" * 70) + print("Pure NVF4 GEMM Benchmark (SM120 Blackwell GeForce)") + print("=" * 70) + + # Get GPU info + props = native.get_device_properties(0) + print(f"GPU: {props.name}") + print(f"SM: {props.compute_capability_major}.{props.compute_capability_minor}") + print() + print("Pre-quantized NVF4 data (no quantization overhead)") + print() + + results = [] + + for size in sizes: + M, N, K = size, size, size + flops = 2.0 * M * N * K # FLOPs for GEMM + + # Allocate output buffer (BF16) + D_gpu = zeros((M, N), dtype="bfloat16") + D_native = D_gpu._get_native() # Get native GPUArray + + # Warmup + for _ in range(warmup): + native.benchmark_gemm_nvf4_sm120(D_native, M, N, K) + native.device_synchronize() + + # Benchmark + times = [] + for _ in range(iterations): + native.device_synchronize() + start = time.perf_counter() + native.benchmark_gemm_nvf4_sm120(D_native, M, N, K) + native.device_synchronize() + end = time.perf_counter() + times.append(end - start) + + median_time = np.median(times) + min_time = np.min(times) + tflops_median = flops / median_time / 1e12 + tflops_max = flops / min_time / 1e12 + + results.append( + { + "size": size, + "tflops_median": tflops_median, + "tflops_max": tflops_max, + "time_ms": median_time * 1000, + } + ) + + print( + f"{M}x{N}x{K}: {tflops_median:.2f} TFLOPS (median), " + f"{tflops_max:.2f} TFLOPS (max), " + f"time={median_time * 1000:.2f}ms" + ) + + print() + print("=" * 70) + print("Summary Table") + print("=" * 70) + print("| 
Size | TFLOPS (median) | TFLOPS (max) | Time (ms) |") + print("|------|-----------------|--------------|-----------|") + for r in results: + print( + f"| {r['size']}x{r['size']} | {r['tflops_median']:.2f} | " + f"{r['tflops_max']:.2f} | {r['time_ms']:.2f} |" + ) + + return results + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Pure NVF4 GEMM Benchmark") + parser.add_argument( + "--sizes", + nargs="+", + type=int, + default=[1024, 2048, 4096, 8192, 12288, 16384], + help="Matrix sizes to benchmark", + ) + parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations") + parser.add_argument("--iterations", type=int, default=20, help="Number of benchmark iterations") + + args = parser.parse_args() + + benchmark_nvf4_nvf4(args.sizes, args.warmup, args.iterations) diff --git a/build.sh b/build.sh index 1702886..7ef337f 100644 --- a/build.sh +++ b/build.sh @@ -3,18 +3,19 @@ # Usage: ./build.sh [SM_VERSION] [CUDA_VERSION] [MODULE_SUFFIX] # # Examples: -# ./build.sh 120 # SM 120, CUDA 12.9 (default) -# ./build.sh 86 # SM 86, CUDA 12.9 -# ./build.sh 120 13.1 # SM 120, CUDA 13.1 +# ./build.sh 120 # SM 120, CUDA 13.1 (default) +# ./build.sh 86 # SM 86, CUDA 13.1 +# ./build.sh 120 12.9 # SM 120, CUDA 12.9 # ./build.sh 86 12.4 # SM 86, CUDA 12.4 -# ./build.sh 120 12.9 _cu129 # SM 120, CUDA 12.9, module suffix _cu129 +# ./build.sh 120 13.1 _cu131 # SM 120, CUDA 13.1, module suffix _cu131 # -# Supported SM versions: 80, 86, 89, 90, 100, 120 +# Supported SM versions: 80, 86, 89, 90, 100, 120, 120a +# Note: Use 120a for full SM120 accelerated features (tensor cores, block-scaled MMA) # Supported CUDA versions: 12.4, 12.9, 13.1 # Module suffix: _cu129, _cu131, or empty for default name -SM_VERSION=${1:-120} -CUDA_VERSION=${2:-12.9} +SM_VERSION=${1:-120a} +CUDA_VERSION=${2:-13.1} MODULE_SUFFIX=${3:-} echo "=== PyGPUkit Build (Git Bash) ===" @@ -44,7 +45,6 @@ set CUDACXX=%CUDA_PATH%\bin\nvcc.exe set CMAKE_CUDA_COMPILER=%CUDA_PATH%\bin\nvcc.exe set CMAKE_ARGS=-DCMAKE_CUDA_ARCHITECTURES=${SM_VERSION} set PYGPUKIT_MODULE_SUFFIX=${MODULE_SUFFIX} -set PYGPUKIT_DISABLE_CUTLASS=1 pip install -e . 
--no-build-isolation EOFBAT diff --git a/docs/api.md b/docs/api.md index 06c49ee..2593245 100644 --- a/docs/api.md +++ b/docs/api.md @@ -186,11 +186,89 @@ def log(a: GPUArray) -> GPUArray: """Element-wise natural logarithm: ln(x)""" ``` +### sin + +```python +def sin(a: GPUArray) -> GPUArray: + """Element-wise sine: sin(x)""" +``` + +### cos + +```python +def cos(a: GPUArray) -> GPUArray: + """Element-wise cosine: cos(x)""" +``` + +### sqrt + +```python +def sqrt(a: GPUArray) -> GPUArray: + """Element-wise square root: sqrt(x)""" +``` + +### rsqrt + +```python +def rsqrt(a: GPUArray) -> GPUArray: + """Element-wise reciprocal square root: 1/sqrt(x)""" +``` + +### abs + +```python +def abs(a: GPUArray) -> GPUArray: + """Element-wise absolute value: |x|""" +``` + +### neg + +```python +def neg(a: GPUArray) -> GPUArray: + """Element-wise negation: -x""" +``` + +**Example:** +```python +a = gpk.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) +b = gpk.exp(a) # [e^1, e^2, e^3] +c = gpk.log(a) # [0, ln(2), ln(3)] +d = gpk.sin(a) # [sin(1), sin(2), sin(3)] +e = gpk.cos(a) # [cos(1), cos(2), cos(3)] +f = gpk.sqrt(a) # [1, 1.414, 1.732] +g = gpk.rsqrt(a) # [1, 0.707, 0.577] +``` + +--- + +## Comparison Operations + +### clamp + +```python +def clamp(a: GPUArray, min_val: float, max_val: float) -> GPUArray: + """Clamp values to range [min_val, max_val].""" +``` + +### where + +```python +def where(cond: GPUArray, x: GPUArray, y: GPUArray) -> GPUArray: + """Element-wise conditional: cond ? x : y""" +``` + **Example:** ```python +x = gpk.from_numpy(np.array([-2.0, 0.5, 3.0], dtype=np.float32)) + +# Clamp to [-1, 1] +y = gpk.clamp(x, -1.0, 1.0) # [-1.0, 0.5, 1.0] + +# Conditional selection +cond = gpk.from_numpy(np.array([1.0, 0.0, 1.0], dtype=np.float32)) a = gpk.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) -b = gpk.exp(a) # [e^1, e^2, e^3] -c = gpk.log(a) # [0, ln(2), ln(3)] +b = gpk.from_numpy(np.array([4.0, 5.0, 6.0], dtype=np.float32)) +result = gpk.where(cond, a, b) # [1.0, 5.0, 3.0] ``` --- @@ -211,11 +289,27 @@ def gelu(a: GPUArray) -> GPUArray: """GELU activation: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))""" ``` +### sigmoid + +```python +def sigmoid(a: GPUArray) -> GPUArray: + """Sigmoid activation: 1 / (1 + exp(-x))""" +``` + +### tanh + +```python +def tanh(a: GPUArray) -> GPUArray: + """Hyperbolic tangent activation: tanh(x)""" +``` + **Example:** ```python x = gpk.from_numpy(np.array([-1.0, 0.0, 1.0, 2.0], dtype=np.float32)) -y_relu = gpk.relu(x) # [0, 0, 1, 2] -y_gelu = gpk.gelu(x) # [-0.159, 0, 0.841, 1.955] +y_relu = gpk.relu(x) # [0, 0, 1, 2] +y_gelu = gpk.gelu(x) # [-0.159, 0, 0.841, 1.955] +y_sigmoid = gpk.sigmoid(x) # [0.269, 0.5, 0.731, 0.881] +y_tanh = gpk.tanh(x) # [-0.762, 0, 0.762, 0.964] ``` --- @@ -305,16 +399,52 @@ def max(a: GPUArray) -> GPUArray: """Maximum element.""" ``` +### min + +```python +def min(a: GPUArray) -> GPUArray: + """Minimum element.""" +``` + +### argmax + +```python +def argmax(a: GPUArray) -> GPUArray: + """Index of maximum element.""" +``` + +### sum_axis + +```python +def sum_axis(a: GPUArray, axis: int) -> GPUArray: + """Sum along specified axis. 
+ + Args: + a: Input array + axis: Axis to reduce (0 for rows, 1 for columns) + + Returns: + Reduced array with axis removed + """ +``` + **Example:** ```python a = gpk.from_numpy(np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)) -total = gpk.sum(a) # [10.0] -avg = gpk.mean(a) # [2.5] -maximum = gpk.max(a) # [4.0] +total = gpk.sum(a) # [10.0] +avg = gpk.mean(a) # [2.5] +maximum = gpk.max(a) # [4.0] +minimum = gpk.min(a) # [1.0] +max_idx = gpk.argmax(a) # [3] (index of 4.0) # Get scalar value print(total.to_numpy()[0]) # 10.0 + +# Sum along axis +mat = gpk.from_numpy(np.array([[1, 2], [3, 4]], dtype=np.float32)) +row_sum = gpk.sum_axis(mat, axis=1) # [3, 7] +col_sum = gpk.sum_axis(mat, axis=0) # [4, 6] ``` --- @@ -418,6 +548,108 @@ output = gpk.linear_bias_gelu(input, weight, bias) --- +## FP8 Operations (SM120+) + +FP8 E4M3 GEMM operations for Blackwell GPUs (RTX 5090, B100, B200). + +### fp8_fp8_sm120_available + +```python +def fp8_fp8_sm120_available() -> bool: + """Check if FP8 I/O GEMM is available (requires SM120+).""" +``` + +### fp8_fp8_get_scale_sizes + +```python +def fp8_fp8_get_scale_sizes(M: int, N: int, K: int) -> tuple[int, int]: + """Get required scale factor sizes for blockwise FP8 GEMM. + + Args: + M: Number of rows in A + N: Number of columns in B + K: Inner dimension + + Returns: + Tuple of (scale_A_size, scale_B_size) + """ +``` + +### matmul_fp8_fp8_sm120 + +```python +def matmul_fp8_fp8_sm120( + a: GPUArray, + b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """FP8 E4M3 GEMM with unity scaling. + + Args: + a: FP8 E4M3 matrix [M, K] (stored as uint8) + b: FP8 E4M3 matrix [K, N] (stored as uint8) + out: Optional output buffer [M, N] + + Returns: + FP8 E4M3 result [M, N] (stored as uint8) + """ +``` + +### matmul_fp8_fp8_blockwise_sm120 + +```python +def matmul_fp8_fp8_blockwise_sm120( + a: GPUArray, + b: GPUArray, + scale_a: GPUArray, + scale_b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """FP8 E4M3 GEMM with blockwise scaling. + + For FP8 models (Llama 3.1 FP8, Qwen FP8, etc.) that store + per-block scale factors alongside quantized weights. + + Args: + a: FP8 E4M3 matrix [M, K] (stored as uint8) + b: FP8 E4M3 matrix [K, N] (stored as uint8) + scale_a: Scale factors for A (size from fp8_fp8_get_scale_sizes) + scale_b: Scale factors for B (size from fp8_fp8_get_scale_sizes) + out: Optional output buffer [M, N] + + Returns: + FP8 E4M3 result [M, N] (stored as uint8) + + Note: + Minimum matrix size is 128x128x128 due to CUTLASS tile requirements. 
+ """ +``` + +**Example:** +```python +import pygpukit as gpk +import numpy as np + +if gpk.fp8_fp8_sm120_available(): + M, N, K = 4096, 4096, 4096 + + # Create FP8 data (stored as uint8) + A = gpk.from_numpy(np.random.randint(0, 255, (M, K), dtype=np.uint8)) + B = gpk.from_numpy(np.random.randint(0, 255, (K, N), dtype=np.uint8)) + + # Get scale sizes and create scale factors + sfa_size, sfb_size = gpk.fp8_fp8_get_scale_sizes(M, N, K) + scale_A = gpk.from_numpy(np.ones(sfa_size, dtype=np.float32)) + scale_B = gpk.from_numpy(np.ones(sfb_size, dtype=np.float32)) + + # Blockwise scaled FP8 GEMM + C = gpk.matmul_fp8_fp8_blockwise_sm120(A, B, scale_A, scale_B) +``` + +--- + ## Device Information ### is_cuda_available diff --git a/examples/chat_cli.py b/examples/chat_cli.py index c0498f1..9cd5647 100644 --- a/examples/chat_cli.py +++ b/examples/chat_cli.py @@ -269,6 +269,23 @@ def main(): action="store_true", help="Enable CUDA Graph for faster decode (reduces kernel launch overhead)", ) + parser.add_argument( + "--speculative", + action="store_true", + help="[EXPERIMENTAL] Enable self-speculative decoding (uses argmax, may cause repetition)", + ) + parser.add_argument( + "--draft-tokens", + type=int, + default=4, + help="Number of draft tokens per speculation round (default: 4)", + ) + parser.add_argument( + "--draft-layers", + type=int, + default=8, + help="Number of early layers to use as draft model (default: 8)", + ) args = parser.parse_args() # Lazy imports for faster --help @@ -280,6 +297,7 @@ def main(): ChatMessage, DecodeM1, DecodeM1Graph, + DecodeSpeculative, detect_model_spec, format_chat_messages, load_model_from_safetensors, @@ -332,9 +350,23 @@ def main(): # Initialize decode strategy use_cuda_graph = args.cuda_graph + use_speculative = args.speculative m1_graph = None - - if use_cuda_graph: + speculative_strategy = None + + if use_speculative: + # Use DecodeSpeculative for self-speculative decoding + print("\nInitializing Self-Speculative Decode...") + print(f" draft_tokens={args.draft_tokens}, draft_layers={args.draft_layers}") + print(" WARNING: Uses argmax (greedy) decoding - may produce repetitive output") + print(" For production use, prefer --cuda-graph instead") + speculative_strategy = DecodeSpeculative( + max_draft_tokens=args.draft_tokens, + draft_layers=args.draft_layers, + ) + speculative_strategy.bind(model) + m1 = None # Not used in speculative mode + elif use_cuda_graph: # Use DecodeM1Graph for CUDA Graph mode print("\nInitializing CUDA Graph...") m1_graph = DecodeM1Graph() @@ -729,9 +761,143 @@ def generate_chunked(messages: list[ChatMessage]) -> tuple[str, float, float, in batch_chunks, ) + def generate_speculative( + messages: list[ChatMessage], + ) -> tuple[str, float, float, int, int, float]: + """Generate using self-speculative decoding. + + Uses early layers as draft model, verifies with full model in batch. + Uses KV snapshot/restore for correctness. + + Returns: (text, prefill_time, decode_time, total_tokens, total_drafts, accept_rate) + """ + prompt = format_chat_messages(messages, model_type=model_type) + input_ids = tokenizer.encode(prompt).ids + + if len(input_ids) >= args.max_seq_len - 10: + return "[Error: Conversation too long. 
Use /clear to reset.]", 0, 0, 0, 0, 0.0 + + # Prefill + t_prefill_start = time.perf_counter() + hidden, past_key_values = model(input_ids, use_cache=True) + for i, block in enumerate(model.blocks): + past_k, past_v = past_key_values[i] + kv_cache_prefill_gqa(past_k, block.attn._k_cache, block.attn.num_heads, start_pos=0) + kv_cache_prefill_gqa(past_v, block.attn._v_cache, block.attn.num_heads, start_pos=0) + default_stream().synchronize() + prefill_time = time.perf_counter() - t_prefill_start + + # Self-speculative decode + t_decode_start = time.perf_counter() + generated_ids: list[int] = [] + stream_decoder = StreamingDecoder(tokenizer) + position = len(input_ids) + context_len = position + 1 + at_start = True + skip_count = 0 + + # Stats + total_drafts = 0 + total_accepted = 0 + + # Get first token from prefill + logits = model.get_logits(hidden) + logits_np = logits_to_f32(logits)[-1] + next_token = sample_token(logits_np, args.temperature, args.top_k, args.top_p) + + # Skip special tokens at start (e.g., <|im_start|>assistant\n) + while should_skip_token(next_token, at_start, skip_count): + if context_len >= args.max_seq_len: + break + # Use fixed cache decode for skipping + hidden = model._decode_step_fixed_cache(next_token, position, context_len) + logits = model.get_logits(hidden) + logits_np = logits_to_f32(logits)[-1] + next_token = sample_token(logits_np, args.temperature, args.top_k, args.top_p) + position += 1 + context_len += 1 + skip_count += 1 + + at_start = False + + # Check if first real token is end token + if is_end_token(next_token): + default_stream().synchronize() + decode_time = time.perf_counter() - t_decode_start + return "", prefill_time, decode_time, 0, 0, 0.0 + + # Output first real token (step_speculative takes this as input and returns NEXT tokens) + text_chunk = stream_decoder.add_token(next_token) + if text_chunk: + print(text_chunk, end="", flush=True) + generated_ids.append(next_token) + + # Main speculative decode loop + while len(generated_ids) < args.max_new_tokens: + if context_len >= args.max_seq_len: + break + + if is_end_token(next_token): + break + + # Run speculative decode step (uses KV snapshot/restore) + accepted_tokens, new_position, stats = speculative_strategy.step_speculative( + next_token, position, context_len + ) + + # Track stats + total_drafts += stats["draft_count"] + total_accepted += stats["accepted_count"] + + # Stream out accepted tokens + for tok in accepted_tokens: + if is_end_token(tok): + break + generated_ids.append(tok) + text_chunk = stream_decoder.add_token(tok) + if text_chunk: + print(text_chunk, end="", flush=True) + + # Check if we hit end token + if any(is_end_token(tok) for tok in accepted_tokens): + break + + # Update position for next iteration + position = new_position + context_len = position + 1 + + # Get next token for next speculation round + if accepted_tokens: + next_token = accepted_tokens[-1] + else: + break + + # Flush any remaining buffered text + remaining = stream_decoder.flush() + if remaining: + print(remaining, end="", flush=True) + + default_stream().synchronize() + decode_time = time.perf_counter() - t_decode_start + + # Calculate acceptance rate + accept_rate = total_accepted / total_drafts if total_drafts > 0 else 0.0 + + print() + return ( + tokenizer.decode(generated_ids), + prefill_time, + decode_time, + len(generated_ids), + total_drafts, + accept_rate, + ) + def generate_response(messages: list[ChatMessage]): """Dispatch to appropriate generation method.""" - if batch_size > 1: + if 
use_speculative: + return generate_speculative(messages) + elif batch_size > 1: return generate_chunked(messages) else: return generate_m1(messages) @@ -741,7 +907,11 @@ def generate_response(messages: list[ChatMessage]): # ========================================================================= print("\n" + "=" * 60) print(" PyGPUkit Chat") - if batch_size > 1: + if use_speculative: + mode_str = ( + f"Self-Speculative (draft_tokens={args.draft_tokens}, draft_layers={args.draft_layers})" + ) + elif batch_size > 1: mode_str = f"Chunked (chunk_size={batch_size})" elif use_cuda_graph: mode_str = "M=1 + CUDA Graph" @@ -781,14 +951,16 @@ def generate_response(messages: list[ChatMessage]): result = generate_response(messages) - if batch_size > 1: + if use_speculative: + response, prefill_time, decode_time, total_tokens, total_drafts, accept_rate = result + tokens_generated = total_tokens + elif batch_size > 1: response, prefill_time, decode_time, total_tokens, accepted_batches = result tokens_generated = total_tokens else: response, prefill_time, decode_time = result # Use length of encoded response, but fallback to 0 if empty tokens_generated = len(tokenizer.encode(response).ids) if response else 0 - accepted_batches = 0 # Add assistant response to history conversation.append(ChatMessage(role="assistant", content=response)) @@ -799,7 +971,9 @@ def generate_response(messages: list[ChatMessage]): f" [prefill: {prefill_time:.1f}s, " f"decode: {tokens_generated} tok / {decode_time:.1f}s = {decode_tps:.1f} tok/s" ) - if batch_size > 1: + if use_speculative: + stats += f", drafts: {total_drafts}, accept: {accept_rate:.1%}" + elif batch_size > 1: stats += f", chunks: {accepted_batches}" stats += "]" print(stats) diff --git a/examples/haru_Info_04.wav b/examples/haru_Info_04.wav new file mode 100644 index 0000000..0565f5c Binary files /dev/null and b/examples/haru_Info_04.wav differ diff --git a/examples/whisper_realtime_stt.py b/examples/whisper_realtime_stt.py new file mode 100644 index 0000000..139c6bf --- /dev/null +++ b/examples/whisper_realtime_stt.py @@ -0,0 +1,615 @@ +#!/usr/bin/env python3 +"""Real-time Speech-to-Text Demo using Whisper. + +This demo shows how to use PyGPUkit's Whisper implementation for +real-time speech recognition from any PCM audio source. + +Supported input sources: +- Microphone (requires sounddevice) +- PCM file (raw audio) +- WAV file + +Usage: + # From microphone (default) + python whisper_realtime_stt.py + + # From WAV file + python whisper_realtime_stt.py --input audio.wav + + # From raw PCM file (16kHz, mono, float32) + python whisper_realtime_stt.py --input audio.pcm --pcm + + # Specify model + python whisper_realtime_stt.py --model kotoba-tech/kotoba-whisper-v2.0 + + # Adjust chunk size (seconds) + python whisper_realtime_stt.py --chunk-size 5.0 + +Requirements: + pip install sounddevice soundfile numpy +""" + +from __future__ import annotations + +import argparse +import sys +import threading +import time +from collections import deque +from dataclasses import dataclass +from typing import Callable + +import numpy as np + +# Audio constants +SAMPLE_RATE = 16000 # Whisper expects 16kHz +CHANNELS = 1 # Mono + + +@dataclass +class TranscriptionEvent: + """Event for transcription results.""" + + text: str + start_time: float + end_time: float + is_partial: bool = False + + +class AudioBuffer: + """Thread-safe audio buffer for real-time processing.""" + + def __init__(self, chunk_duration: float = 5.0, overlap: float = 0.5): + """Initialize audio buffer. 
+ + Args: + chunk_duration: Duration of each chunk in seconds + overlap: Overlap between chunks in seconds + """ + self.chunk_samples = int(chunk_duration * SAMPLE_RATE) + self.overlap_samples = int(overlap * SAMPLE_RATE) + self.stride_samples = self.chunk_samples - self.overlap_samples + + self._buffer: deque = deque() + self._lock = threading.Lock() + self._total_samples = 0 + + def write(self, audio: np.ndarray) -> None: + """Write audio samples to buffer.""" + with self._lock: + self._buffer.extend(audio.flatten()) + self._total_samples += len(audio.flatten()) + + def read_chunk(self) -> tuple[np.ndarray, float] | None: + """Read a chunk of audio if available. + + Returns: + Tuple of (audio_chunk, start_time) or None if not enough data + """ + with self._lock: + if len(self._buffer) < self.chunk_samples: + return None + + # Extract chunk + chunk = np.array([self._buffer[i] for i in range(self.chunk_samples)]) + + # Calculate start time + consumed = self._total_samples - len(self._buffer) + start_time = consumed / SAMPLE_RATE + + # Remove processed samples (keeping overlap) + for _ in range(self.stride_samples): + if self._buffer: + self._buffer.popleft() + + return chunk.astype(np.float32), start_time + + @property + def buffered_duration(self) -> float: + """Get buffered duration in seconds.""" + with self._lock: + return len(self._buffer) / SAMPLE_RATE + + +class RealtimeSTT: + """Real-time Speech-to-Text engine using Whisper.""" + + def __init__( + self, + model_id: str = "kotoba-tech/kotoba-whisper-v2.0", + chunk_duration: float = 5.0, + language: str | None = None, + on_transcription: Callable[[TranscriptionEvent], None] | None = None, + ): + """Initialize real-time STT. + + Args: + model_id: Whisper model ID or path + chunk_duration: Duration of each chunk in seconds + language: Language code (e.g., "ja", "en") + on_transcription: Callback for transcription events + """ + self.model_id = model_id + self.chunk_duration = chunk_duration + self.language = language + self.on_transcription = on_transcription + + self._model = None + self._buffer = AudioBuffer(chunk_duration=chunk_duration) + self._running = False + self._thread: threading.Thread | None = None + + def load_model(self) -> None: + """Load Whisper model.""" + print(f"Loading model: {self.model_id}...") + from pygpukit.asr import WhisperModel + + self._model = WhisperModel.from_pretrained(self.model_id) + print("Model loaded successfully!") + + def start(self) -> None: + """Start the transcription thread.""" + if self._model is None: + self.load_model() + + self._running = True + self._thread = threading.Thread(target=self._transcription_loop, daemon=True) + self._thread.start() + + def stop(self) -> None: + """Stop the transcription thread.""" + self._running = False + if self._thread: + self._thread.join(timeout=2.0) + + def feed_audio(self, audio: np.ndarray) -> None: + """Feed audio samples to the STT engine. 
+ + Args: + audio: Audio samples (float32, -1.0 to 1.0) + """ + self._buffer.write(audio) + + def _transcription_loop(self) -> None: + """Background loop for processing audio chunks.""" + while self._running: + chunk_data = self._buffer.read_chunk() + + if chunk_data is None: + time.sleep(0.1) + continue + + audio_chunk, start_time = chunk_data + + try: + # Transcribe chunk + result = self._model.transcribe( + audio_chunk, + language=self.language, + temperature=0.0, + ) + + # Create event + event = TranscriptionEvent( + text=result.text.strip(), + start_time=start_time, + end_time=start_time + len(audio_chunk) / SAMPLE_RATE, + ) + + # Callback + if self.on_transcription and event.text: + self.on_transcription(event) + + except Exception as e: + print(f"Transcription error: {e}", file=sys.stderr) + + +def read_pcm_file(path: str, sample_rate: int = SAMPLE_RATE) -> np.ndarray: + """Read raw PCM file. + + Args: + path: Path to PCM file + sample_rate: Expected sample rate + + Returns: + Audio array (float32) + """ + # Try to read as float32 first, then int16 + try: + audio = np.fromfile(path, dtype=np.float32) + if np.abs(audio).max() > 10: # Probably int16 + raise ValueError("Not float32") + except (ValueError, Exception): + audio = np.fromfile(path, dtype=np.int16).astype(np.float32) / 32768.0 + + return audio + + +def read_wav_file(path: str) -> tuple[np.ndarray, int]: + """Read WAV file. + + Args: + path: Path to WAV file + + Returns: + Tuple of (audio, sample_rate) + """ + try: + import soundfile as sf + + audio, sr = sf.read(path) + if audio.ndim > 1: + audio = audio.mean(axis=1) + return audio.astype(np.float32), sr + except ImportError as err: + raise ImportError("soundfile is required: pip install soundfile") from err + + +def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray: + """Resample audio to target sample rate. + + Args: + audio: Input audio + orig_sr: Original sample rate + target_sr: Target sample rate + + Returns: + Resampled audio + """ + if orig_sr == target_sr: + return audio + + try: + import resampy + + return resampy.resample(audio, orig_sr, target_sr) + except ImportError: + # Simple linear interpolation fallback + duration = len(audio) / orig_sr + target_len = int(duration * target_sr) + indices = np.linspace(0, len(audio) - 1, target_len) + return np.interp(indices, np.arange(len(audio)), audio).astype(np.float32) + + +class MicrophoneStream: + """Microphone audio stream.""" + + def __init__( + self, + sample_rate: int = SAMPLE_RATE, + chunk_size: int = 1024, + device: int | None = None, + ): + self.sample_rate = sample_rate + self.chunk_size = chunk_size + self.device = device + self._stream = None + + def start(self, callback: Callable[[np.ndarray], None]) -> None: + """Start microphone stream. 
+ + Args: + callback: Function to call with audio chunks + """ + try: + import sounddevice as sd + except ImportError as err: + raise ImportError( + "sounddevice is required for microphone: pip install sounddevice" + ) from err + + def audio_callback(indata, frames, time_info, status): + if status: + print(f"Audio status: {status}", file=sys.stderr) + callback(indata.copy()) + + self._stream = sd.InputStream( + samplerate=self.sample_rate, + channels=CHANNELS, + dtype=np.float32, + blocksize=self.chunk_size, + device=self.device, + callback=audio_callback, + ) + self._stream.start() + + def stop(self) -> None: + """Stop microphone stream.""" + if self._stream: + self._stream.stop() + self._stream.close() + + +def print_transcription(event: TranscriptionEvent) -> None: + """Print transcription event to console.""" + timestamp = f"[{event.start_time:6.1f}s - {event.end_time:6.1f}s]" + print(f"{timestamp} {event.text}") + + +def list_audio_devices() -> list[dict]: + """List available audio input devices. + + Returns: + List of device info dicts with 'index', 'name', 'channels', 'sample_rate' + """ + try: + import sounddevice as sd + except ImportError as err: + raise ImportError("sounddevice is required: pip install sounddevice") from err + + devices = [] + for i, dev in enumerate(sd.query_devices()): + if dev["max_input_channels"] > 0: # Input device + devices.append( + { + "index": i, + "name": dev["name"], + "channels": dev["max_input_channels"], + "sample_rate": dev["default_samplerate"], + } + ) + return devices + + +def print_audio_devices() -> None: + """Print available audio input devices.""" + devices = list_audio_devices() + print("\nAvailable audio input devices:") + print("-" * 60) + for dev in devices: + print(f" [{dev['index']:2d}] {dev['name']}") + print(f" Channels: {dev['channels']}, Sample Rate: {dev['sample_rate']:.0f} Hz") + print("-" * 60) + + +def select_audio_device() -> int | None: + """Interactively select an audio input device. + + Returns: + Selected device index or None for default + """ + devices = list_audio_devices() + + if not devices: + print("No audio input devices found!") + return None + + if len(devices) == 1: + print(f"Using audio device: {devices[0]['name']}") + return devices[0]["index"] + + print("\nAvailable audio input devices:") + print("-" * 60) + for dev in devices: + print(f" [{dev['index']:2d}] {dev['name']}") + print("-" * 60) + + while True: + try: + choice = input( + f"Select device [0-{max(d['index'] for d in devices)}, Enter=default]: " + ).strip() + if choice == "": + return None + idx = int(choice) + if any(d["index"] == idx for d in devices): + return idx + print(f"Invalid device index: {idx}") + except ValueError: + print("Please enter a valid number") + except KeyboardInterrupt: + print("\nCancelled") + sys.exit(0) + + +def demo_microphone(args: argparse.Namespace) -> None: + """Run demo with microphone input.""" + # Select device if not specified + device = args.device + if device is None and args.select_device: + device = select_audio_device() + + print("=" * 60) + print("Real-time Speech-to-Text Demo (Microphone)") + print("=" * 60) + print(f"Model: {args.model}") + print(f"Language: {args.language or 'auto'}") + print(f"Chunk size: {args.chunk_size}s") + if device is not None: + print(f"Device: {device}") + print("-" * 60) + print("Speak into your microphone. 
Press Ctrl+C to stop.") + print("-" * 60) + + # Initialize STT + stt = RealtimeSTT( + model_id=args.model, + chunk_duration=args.chunk_size, + language=args.language, + on_transcription=print_transcription, + ) + stt.load_model() + + # Start microphone + mic = MicrophoneStream(device=device) + + try: + stt.start() + mic.start(stt.feed_audio) + + # Keep running until Ctrl+C + while True: + time.sleep(0.1) + + except KeyboardInterrupt: + print("\nStopping...") + finally: + mic.stop() + stt.stop() + + +def demo_file(args: argparse.Namespace) -> None: + """Run demo with file input.""" + print("=" * 60) + print("Real-time Speech-to-Text Demo (File)") + print("=" * 60) + print(f"Model: {args.model}") + print(f"Input: {args.input}") + print(f"Language: {args.language or 'auto'}") + print(f"Chunk size: {args.chunk_size}s") + print("-" * 60) + + # Load audio + if args.pcm: + print("Loading PCM file...") + audio = read_pcm_file(args.input) + sr = args.sample_rate + else: + print("Loading audio file...") + audio, sr = read_wav_file(args.input) + + # Resample if needed + if sr != SAMPLE_RATE: + print(f"Resampling from {sr}Hz to {SAMPLE_RATE}Hz...") + audio = resample_audio(audio, sr, SAMPLE_RATE) + + print(f"Audio duration: {len(audio) / SAMPLE_RATE:.1f}s") + print("-" * 60) + + # Initialize STT + stt = RealtimeSTT( + model_id=args.model, + chunk_duration=args.chunk_size, + language=args.language, + on_transcription=print_transcription, + ) + stt.load_model() + + # Process audio in real-time simulation + stt.start() + + # Feed audio in chunks (simulating real-time) + chunk_samples = int(0.1 * SAMPLE_RATE) # 100ms chunks + try: + for i in range(0, len(audio), chunk_samples): + chunk = audio[i : i + chunk_samples] + stt.feed_audio(chunk) + + # Simulate real-time by sleeping + if not args.fast: + time.sleep(len(chunk) / SAMPLE_RATE) + + # Wait for processing to complete + print("\nProcessing remaining audio...") + time.sleep(args.chunk_size + 1) + + except KeyboardInterrupt: + print("\nStopping...") + finally: + stt.stop() + + +def main(): + parser = argparse.ArgumentParser( + description="Real-time Speech-to-Text Demo using Whisper", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # List available microphones + python whisper_realtime_stt.py --list-devices + + # Select microphone interactively + python whisper_realtime_stt.py --select-device + + # Use specific microphone by index + python whisper_realtime_stt.py --device 2 + + # WAV file input + python whisper_realtime_stt.py --input recording.wav + + # Raw PCM file (16kHz, mono, float32) + python whisper_realtime_stt.py --input audio.pcm --pcm + + # Japanese model with 3-second chunks + python whisper_realtime_stt.py --model kotoba-tech/kotoba-whisper-v2.0 \\ + --language ja --chunk-size 3.0 +""", + ) + + parser.add_argument( + "--input", + "-i", + type=str, + default=None, + help="Input audio file (WAV or PCM). If not specified, uses microphone.", + ) + parser.add_argument( + "--pcm", + action="store_true", + help="Treat input as raw PCM file", + ) + parser.add_argument( + "--sample-rate", + type=int, + default=SAMPLE_RATE, + help=f"Sample rate for PCM input (default: {SAMPLE_RATE})", + ) + parser.add_argument( + "--model", + "-m", + type=str, + default="kotoba-tech/kotoba-whisper-v2.0", + help="Whisper model ID or path", + ) + parser.add_argument( + "--language", + "-l", + type=str, + default=None, + help="Language code (e.g., 'ja', 'en'). 
Auto-detect if not specified.", + ) + parser.add_argument( + "--chunk-size", + type=float, + default=5.0, + help="Chunk duration in seconds (default: 5.0)", + ) + parser.add_argument( + "--device", + "-d", + type=int, + default=None, + help="Audio input device index (for microphone)", + ) + parser.add_argument( + "--list-devices", + action="store_true", + help="List available audio input devices and exit", + ) + parser.add_argument( + "--select-device", + "-s", + action="store_true", + help="Interactively select audio input device at startup", + ) + parser.add_argument( + "--fast", + action="store_true", + help="Process file as fast as possible (no real-time simulation)", + ) + + args = parser.parse_args() + + # List devices and exit + if args.list_devices: + print_audio_devices() + return + + if args.input: + demo_file(args) + else: + demo_microphone(args) + + +if __name__ == "__main__": + main() diff --git a/native/CMakeLists.txt b/native/CMakeLists.txt index 627ea89..2687f53 100644 --- a/native/CMakeLists.txt +++ b/native/CMakeLists.txt @@ -33,11 +33,24 @@ include_directories(${CUDAToolkit_INCLUDE_DIRS}) # CUTLASS (header-only library) # Can be disabled via environment variable PYGPUKIT_DISABLE_CUTLASS=1 -set(CUTLASS_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../third_party/cutlass") +# Try multiple paths for CUTLASS (scikit-build-core may change working directory) +set(CUTLASS_DIR_CANDIDATES + "${CMAKE_CURRENT_SOURCE_DIR}/../third_party/cutlass" + "${CMAKE_CURRENT_LIST_DIR}/../third_party/cutlass" + "${CMAKE_SOURCE_DIR}/../third_party/cutlass" +) +set(CUTLASS_FOUND FALSE) +foreach(CUTLASS_CANDIDATE ${CUTLASS_DIR_CANDIDATES}) + if(EXISTS "${CUTLASS_CANDIDATE}/include" AND NOT CUTLASS_FOUND) + set(CUTLASS_DIR "${CUTLASS_CANDIDATE}") + set(CUTLASS_FOUND TRUE) + endif() +endforeach() + if(DEFINED ENV{PYGPUKIT_DISABLE_CUTLASS}) message(STATUS "CUTLASS disabled via PYGPUKIT_DISABLE_CUTLASS environment variable") add_definitions(-DPYGPUKIT_HAS_CUTLASS=0) -elseif(EXISTS "${CUTLASS_DIR}/include") +elseif(CUTLASS_FOUND) message(STATUS "CUTLASS found at: ${CUTLASS_DIR}") include_directories(${CUTLASS_DIR}/include) include_directories(${CUTLASS_DIR}/tools/util/include) @@ -46,7 +59,8 @@ elseif(EXISTS "${CUTLASS_DIR}/include") # Disabled for now - will be enabled when SM90+ testing is available # add_definitions(-DCUTLASS_ARCH_MMA_SM90_SUPPORTED=1) else() - message(STATUS "CUTLASS not found, using fallback kernels") + message(STATUS "CUTLASS not found at any of: ${CUTLASS_DIR_CANDIDATES}") + message(STATUS "Using fallback kernels") add_definitions(-DPYGPUKIT_HAS_CUTLASS=0) endif() @@ -69,6 +83,39 @@ endif() message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") +# Enable CUTLASS SM support based on target architectures +# _SUPPORTED macros enable host-side type definitions +# _ENABLED macros are auto-defined by CUTLASS based on __CUDA_ARCH__ during device compilation +string(FIND "${CMAKE_CUDA_ARCHITECTURES}" "90" SM90_POS) +string(FIND "${CMAKE_CUDA_ARCHITECTURES}" "100" SM100_POS) +string(FIND "${CMAKE_CUDA_ARCHITECTURES}" "120" SM120_POS) + +# SM90 (Hopper) - FP8 GEMM with per-tensor scaling +# Also enable for SM100+ since they are forward compatible +if(NOT SM90_POS EQUAL -1 OR NOT SM100_POS EQUAL -1 OR NOT SM120_POS EQUAL -1) + message(STATUS "Enabling CUTLASS SM90 (Hopper) support") + add_definitions(-DCUTLASS_ARCH_MMA_SM90_SUPPORTED=1) +endif() + +# SM100 (Blackwell datacenter) +# Also enable for SM120 since they are both Blackwell architecture +if(NOT SM100_POS EQUAL -1 OR NOT 
SM120_POS EQUAL -1) + message(STATUS "Enabling CUTLASS SM100 (Blackwell datacenter) support") + add_definitions(-DCUTLASS_ARCH_MMA_SM100_SUPPORTED=1) +endif() + +# SM120 (Blackwell GeForce) - FP8 GEMM with blockwise scaling +# Note: Use 120a for full accelerated features (tensor cores, block-scaled MMA) +if(NOT SM120_POS EQUAL -1) + message(STATUS "Enabling CUTLASS SM120 (Blackwell GeForce) support") + add_definitions(-DCUTLASS_ARCH_MMA_SM120_SUPPORTED=1) + # For SM120a (full accelerated features), also enable feature macros + string(FIND "${CMAKE_CUDA_ARCHITECTURES}" "120a" SM120A_POS) + if(NOT SM120A_POS EQUAL -1) + message(STATUS " SM120a: Full accelerated features enabled") + endif() +endif() + # Ampere-optimized compiler flags # Add -v for verbose ptxas output to check register usage # NOTE: Do NOT use -maxrregcount for CUTLASS - it needs many registers for optimal performance @@ -106,6 +153,13 @@ pybind11_add_module(${MODULE_NAME} ops/reduction/reduction.cu ops/matmul/matmul.cu ops/matmul/matmul_cutlass.cu + ops/matmul/matmul_fp8_sm90.cu + ops/matmul/matmul_fp8_sm100.cu + ops/matmul/matmul_fp8_fp32_sm120.cu + ops/matmul/matmul_fp8_fp8_sm120.cu + ops/matmul/matmul_nvf4_bf16_sm120.cu + ops/matmul/matmul_nvf4_nvf4_sm120.cu + ops/gemv/gemv_nvf4.cu ops/nn/nn.cu ops/quantize/quantize.cu ops/attention/paged_attention.cu diff --git a/native/bindings/core_bindings.cpp b/native/bindings/core_bindings.cpp index b5361e7..de57203 100644 --- a/native/bindings/core_bindings.cpp +++ b/native/bindings/core_bindings.cpp @@ -189,12 +189,21 @@ void init_core_bindings(py::module_& m) { dtype = DataType::Int32; } else if (itemsize == 2) { dtype = DataType::Int16; + } else if (itemsize == 1) { + dtype = DataType::Int8; } else { throw std::runtime_error("Unsupported int dtype size: " + std::to_string(itemsize)); } - } else if (kind == 'u' && itemsize == 2) { - // uint16 can be used for bfloat16 storage - dtype = DataType::BFloat16; + } else if (kind == 'u') { + // Unsigned integer types + if (itemsize == 1) { + dtype = DataType::UInt8; + } else if (itemsize == 2) { + // uint16 can be used for bfloat16 storage + dtype = DataType::BFloat16; + } else { + throw std::runtime_error("Unsupported uint dtype size: " + std::to_string(itemsize)); + } } else { throw std::runtime_error("Unsupported numpy dtype"); } diff --git a/native/bindings/ops_bindings.cpp b/native/bindings/ops_bindings.cpp index 88d8400..186dfd3 100644 --- a/native/bindings/ops_bindings.cpp +++ b/native/bindings/ops_bindings.cpp @@ -8,6 +8,92 @@ namespace py = pybind11; using namespace pygpukit; +// Extern declarations for FP8 functions (must be at global scope) +extern "C" { + // SM90 (Hopper) - FP8 with per-tensor scaling + cudaError_t pygpukit_gemm_fp8_sm90( + const float* A, const float* B, float* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ); + bool pygpukit_fp8_sm90_available(); + + // SM100 (Blackwell datacenter) - FP8 with blockwise scaling + cudaError_t pygpukit_gemm_fp8_sm100( + const float* A, const float* B, float* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ); + bool pygpukit_fp8_sm100_available(); + + // SM120 (Blackwell GeForce) - FP8 with blockwise scaling (disabled due to CUTLASS bug #2902) + cudaError_t pygpukit_gemm_fp8_sm120( + const float* A, const float* B, float* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ); + bool pygpukit_fp8_sm120_available(); + + // SM120 (Blackwell GeForce) - Pure FP8 I/O GEMM + cudaError_t 
pygpukit_gemm_fp8_fp8_sm120( + const uint8_t* A, const uint8_t* B, uint8_t* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ); + bool pygpukit_fp8_fp8_sm120_available(); + + // SM120 (Blackwell GeForce) - Pure FP8 I/O GEMM with blockwise scaling + cudaError_t pygpukit_gemm_fp8_fp8_blockwise_sm120( + const uint8_t* A, const uint8_t* B, uint8_t* D, + const float* scale_A, const float* scale_B, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ); + void pygpukit_fp8_fp8_get_scale_sizes( + int M, int N, int K, + size_t* sfa_size, size_t* sfb_size + ); + + // SM120 (Blackwell GeForce) - NVF4 (4-bit) with BF16 I/O + cudaError_t pygpukit_gemm_nvf4_bf16_sm120( + const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ); + bool pygpukit_nvf4_bf16_sm120_available(); + + // SM120 (Blackwell GeForce) - Pure NVF4 GEMM (for benchmarking) + cudaError_t pygpukit_benchmark_gemm_nvf4_sm120( + __nv_bfloat16* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ); + bool pygpukit_nvf4_nvf4_sm120_available(); + + // NVF4 GEMV for SM120 + bool pygpukit_gemv_nvf4_available(); + cudaError_t pygpukit_quantize_bf16_to_nvf4( + const void* input, void* out_data, void* out_scale, + int K, int N, cudaStream_t stream + ); + cudaError_t pygpukit_gemv_nvf4_bf16( + const void* A, const void* B_data, const void* B_scale, void* C, + int K, int N, float alpha, cudaStream_t stream + ); + cudaError_t pygpukit_gemv_bf16( + const void* A, const void* B, void* C, + int K, int N, float alpha, float beta, cudaStream_t stream + ); + void pygpukit_nvf4_get_sizes(int K, int N, size_t* data_size, size_t* scale_size); +} + void init_ops_bindings(py::module_& m) { // ======================================================================== // Binary Element-wise operations @@ -80,6 +166,78 @@ void init_ops_bindings(py::module_& m) { py::arg("a"), py::arg("out"), "Element-wise ReLU with output array"); + // Sin + m.def("sin", py::overload_cast(&ops::sin), + py::arg("a"), + "Element-wise sine"); + + m.def("sin_", py::overload_cast(&ops::sin), + py::arg("a"), py::arg("out"), + "Element-wise sine with output array"); + + // Cos + m.def("cos", py::overload_cast(&ops::cos), + py::arg("a"), + "Element-wise cosine"); + + m.def("cos_", py::overload_cast(&ops::cos), + py::arg("a"), py::arg("out"), + "Element-wise cosine with output array"); + + // Sqrt + m.def("sqrt", py::overload_cast(&ops::sqrt), + py::arg("a"), + "Element-wise square root"); + + m.def("sqrt_", py::overload_cast(&ops::sqrt), + py::arg("a"), py::arg("out"), + "Element-wise square root with output array"); + + // Rsqrt + m.def("rsqrt", py::overload_cast(&ops::rsqrt), + py::arg("a"), + "Element-wise reciprocal square root: 1/sqrt(x)"); + + m.def("rsqrt_", py::overload_cast(&ops::rsqrt), + py::arg("a"), py::arg("out"), + "Element-wise reciprocal square root with output array"); + + // Abs + m.def("abs", py::overload_cast(&ops::abs), + py::arg("a"), + "Element-wise absolute value"); + + m.def("abs_", py::overload_cast(&ops::abs), + py::arg("a"), py::arg("out"), + "Element-wise absolute value with output array"); + + // Neg + m.def("neg", py::overload_cast(&ops::neg), + py::arg("a"), + "Element-wise negation: -x"); + + m.def("neg_", py::overload_cast(&ops::neg), + py::arg("a"), py::arg("out"), + "Element-wise negation with output array"); + + // Clamp + m.def("clamp", py::overload_cast(&ops::clamp), + py::arg("a"), 
py::arg("min_val"), py::arg("max_val"), + "Element-wise clamp: clamp(x, min, max)"); + + m.def("clamp_", py::overload_cast(&ops::clamp), + py::arg("a"), py::arg("out"), py::arg("min_val"), py::arg("max_val"), + "Element-wise clamp with output array"); + + // Where (conditional select) + m.def("where", py::overload_cast(&ops::where), + py::arg("cond"), py::arg("a"), py::arg("b"), + "Conditional select: where(cond, a, b) = cond ? a : b"); + + m.def("where_", py::overload_cast(&ops::where), + py::arg("cond"), py::arg("a"), py::arg("b"), py::arg("out"), + "Conditional select with output array"); + // ======================================================================== // Matrix operations // ======================================================================== @@ -117,6 +275,19 @@ void init_ops_bindings(py::module_& m) { py::arg("a"), "Max of all elements (float32/float64 only), returns scalar GPUArray"); + m.def("min", &ops::min, + py::arg("a"), + "Min of all elements, returns scalar GPUArray"); + + m.def("argmax", &ops::argmax, + py::arg("a"), + "Index of maximum element, returns int64 GPUArray"); + + m.def("sum_axis", &ops::sum_axis, + py::arg("a"), py::arg("axis"), + "Sum along specified axis (0 or 1) for 2D tensors.\n" + "axis=0: sum rows -> [N], axis=1: sum columns -> [M]"); + // ======================================================================== // Neural Network operations // ======================================================================== @@ -184,6 +355,24 @@ void init_ops_bindings(py::module_& m) { py::arg("input"), py::arg("out"), "SiLU with output buffer (for CUDA Graph capture)"); + // Sigmoid activation + m.def("sigmoid", py::overload_cast(&ops::sigmoid), + py::arg("input"), + "Sigmoid activation: y = 1 / (1 + exp(-x))"); + + m.def("sigmoid_", py::overload_cast(&ops::sigmoid), + py::arg("input"), py::arg("out"), + "Sigmoid with output buffer (for CUDA Graph capture)"); + + // Tanh activation + m.def("tanh", py::overload_cast(&ops::tanh), + py::arg("input"), + "Tanh activation"); + + m.def("tanh_", py::overload_cast(&ops::tanh), + py::arg("input"), py::arg("out"), + "Tanh with output buffer (for CUDA Graph capture)"); + // RoPE (Rotary Position Embedding) - In-place m.def("rope_inplace", &ops::rope_inplace, py::arg("q"), py::arg("k"), py::arg("cos"), py::arg("sin"), @@ -263,6 +452,36 @@ void init_ops_bindings(py::module_& m) { py::arg("input"), py::arg("out"), "Transpose 3D tensor with output buffer (for CUDA Graph capture)"); + // Transpose 4D: [d0, d1, d2, d3] -> [d0, d2, d1, d3] + m.def("transpose_4d_0213", py::overload_cast(&ops::transpose_4d_0213), + py::arg("input"), + "Transpose 4D tensor: [d0, d1, d2, d3] -> [d0, d2, d1, d3] (swap axes 1 and 2)"); + + // Transpose 4D with output buffer (for CUDA Graph capture) + m.def("transpose_4d_0213_", py::overload_cast(&ops::transpose_4d_0213), + py::arg("input"), py::arg("out"), + "Transpose 4D tensor with output buffer (for CUDA Graph capture)"); + + // Transpose 3D: [d0, d1, d2] -> [d0, d2, d1] (swap last two axes) + m.def("transpose_3d_012", py::overload_cast(&ops::transpose_3d_012), + py::arg("input"), + "Transpose 3D tensor: [d0, d1, d2] -> [d0, d2, d1] (swap last two axes)"); + + // Transpose 3D with output buffer (for CUDA Graph capture) + m.def("transpose_3d_012_", py::overload_cast(&ops::transpose_3d_012), + py::arg("input"), py::arg("out"), + "Transpose 3D tensor with output buffer (for CUDA Graph capture)"); + + // Transpose 4D: [d0, d1, d2, d3] -> [d0, d1, d3, d2] (swap last two axes) + 
m.def("transpose_4d_0132", py::overload_cast(&ops::transpose_4d_0132), + py::arg("input"), + "Transpose 4D tensor: [d0, d1, d2, d3] -> [d0, d1, d3, d2] (swap last two axes)"); + + // Transpose 4D with output buffer (for CUDA Graph capture) + m.def("transpose_4d_0132_", py::overload_cast(&ops::transpose_4d_0132), + py::arg("input"), py::arg("out"), + "Transpose 4D tensor with output buffer (for CUDA Graph capture)"); + // Reshape with copy m.def("reshape_copy", py::overload_cast&>(&ops::reshape_copy), py::arg("input"), py::arg("new_shape"), @@ -1087,4 +1306,468 @@ void init_ops_bindings(py::module_& m) { auto handle = cublaslt::get_handle(); return reinterpret_cast(handle); }, "Get cuBLASLt handle address for debugging (0 if not available)."); + + // ======================================================================== + // Strided Batched GEMM (for batched matmul in attention) + // ======================================================================== + + m.def("gemm_strided_batched_fp32", &ops::batched_matmul_fp32, + py::arg("A"), py::arg("B"), py::arg("C"), + py::arg("M"), py::arg("N"), py::arg("K"), py::arg("batch_count"), + py::arg("strideA"), py::arg("strideB"), py::arg("strideC"), + "Strided batched GEMM: C[b] = A[b] @ B[b] for b in [0, batch_count)"); + + // ======================================================================== + // FP8 GEMM for SM90 (Hopper) - per-tensor scaling + // ======================================================================== + + m.def("fp8_sm90_available", []() { + return pygpukit_fp8_sm90_available(); + }, "Check if FP8 GEMM is available on SM90 (Hopper)"); + + m.def("gemm_fp8_sm90", [](const GPUArray& A, const GPUArray& B, GPUArray& D) { + if (A.dtype() != DataType::Float32 || B.dtype() != DataType::Float32 || D.dtype() != DataType::Float32) { + throw std::runtime_error("gemm_fp8_sm90: all inputs must be float32"); + } + if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) { + throw std::runtime_error("gemm_fp8_sm90: all inputs must be 2D"); + } + + int M = A.shape()[0]; + int K = A.shape()[1]; + int N = B.shape()[1]; + + if (B.shape()[0] != static_cast(K)) { + throw std::runtime_error("gemm_fp8_sm90: A.shape[1] must equal B.shape[0]"); + } + if (D.shape()[0] != static_cast(M) || D.shape()[1] != static_cast(N)) { + throw std::runtime_error("gemm_fp8_sm90: D shape mismatch"); + } + + cudaError_t err = pygpukit_gemm_fp8_sm90( + static_cast(A.data()), + static_cast(B.data()), + static_cast(D.data()), + M, N, K, + 1.0f, 0.0f, + nullptr + ); + + if (err != cudaSuccess) { + throw std::runtime_error("gemm_fp8_sm90 failed: " + std::string(cudaGetErrorString(err))); + } + }, py::arg("A"), py::arg("B"), py::arg("D"), + "FP8 GEMM for SM90 (Hopper): D = A @ B (with FP8 quantization internally)"); + + // ======================================================================== + // FP8 GEMM for SM100 (Blackwell datacenter) - blockwise scaling + // Potential fallback for SM120 (same Blackwell architecture) + // ======================================================================== + + m.def("fp8_sm100_available", []() { + return pygpukit_fp8_sm100_available(); + }, "Check if FP8 GEMM is available on SM100 (Blackwell datacenter)"); + + m.def("gemm_fp8_sm100", [](const GPUArray& A, const GPUArray& B, GPUArray& D) { + if (A.dtype() != DataType::Float32 || B.dtype() != DataType::Float32 || D.dtype() != DataType::Float32) { + throw std::runtime_error("gemm_fp8_sm100: all inputs must be float32"); + } + if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) { + 
throw std::runtime_error("gemm_fp8_sm100: all inputs must be 2D"); + } + + int M = A.shape()[0]; + int K = A.shape()[1]; + int N = B.shape()[1]; + + if (B.shape()[0] != static_cast(K)) { + throw std::runtime_error("gemm_fp8_sm100: A.shape[1] must equal B.shape[0]"); + } + if (D.shape()[0] != static_cast(M) || D.shape()[1] != static_cast(N)) { + throw std::runtime_error("gemm_fp8_sm100: D shape mismatch"); + } + + cudaError_t err = pygpukit_gemm_fp8_sm100( + static_cast(A.data()), + static_cast(B.data()), + static_cast(D.data()), + M, N, K, + 1.0f, 0.0f, + nullptr + ); + + if (err != cudaSuccess) { + throw std::runtime_error("gemm_fp8_sm100 failed: " + std::string(cudaGetErrorString(err))); + } + }, py::arg("A"), py::arg("B"), py::arg("D"), + "FP8 GEMM for SM100 (Blackwell datacenter): D = A @ B (with FP8 quantization internally)"); + + // ======================================================================== + // FP8 GEMM for SM120 (Blackwell GeForce) - blockwise scaling + // NOTE: Currently disabled due to CUTLASS bug #2902 + // ======================================================================== + + m.def("fp8_sm120_available", []() { + return pygpukit_fp8_sm120_available(); + }, "Check if FP8 GEMM is available on SM120 (currently disabled due to CUTLASS bug)"); + + m.def("gemm_fp8_sm120", [](const GPUArray& A, const GPUArray& B, GPUArray& D) { + if (A.dtype() != DataType::Float32 || B.dtype() != DataType::Float32 || D.dtype() != DataType::Float32) { + throw std::runtime_error("gemm_fp8_sm120: all inputs must be float32"); + } + if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) { + throw std::runtime_error("gemm_fp8_sm120: all inputs must be 2D"); + } + + int M = A.shape()[0]; + int K = A.shape()[1]; + int N = B.shape()[1]; + + if (B.shape()[0] != static_cast(K)) { + throw std::runtime_error("gemm_fp8_sm120: A.shape[1] must equal B.shape[0]"); + } + if (D.shape()[0] != static_cast(M) || D.shape()[1] != static_cast(N)) { + throw std::runtime_error("gemm_fp8_sm120: D shape mismatch"); + } + + cudaError_t err = pygpukit_gemm_fp8_sm120( + static_cast(A.data()), + static_cast(B.data()), + static_cast(D.data()), + M, N, K, + 1.0f, 0.0f, + nullptr + ); + + if (err != cudaSuccess) { + throw std::runtime_error("gemm_fp8_sm120 failed: " + std::string(cudaGetErrorString(err))); + } + }, py::arg("A"), py::arg("B"), py::arg("D"), + "FP8 GEMM for SM120: D = A @ B (with FP8 quantization internally)"); + + // ======================================================================== + // Pure FP8 I/O GEMM for SM120 (FP8 models) + // ======================================================================== + + m.def("fp8_fp8_sm120_available", []() { + return pygpukit_fp8_fp8_sm120_available(); + }, "Check if Pure FP8 I/O GEMM is available on SM120"); + + m.def("gemm_fp8_fp8_sm120", [](const GPUArray& A, const GPUArray& B, GPUArray& D) { + // FP8 is stored as UInt8 in GPUArray + if (A.dtype() != DataType::UInt8 || B.dtype() != DataType::UInt8 || D.dtype() != DataType::UInt8) { + throw std::runtime_error("gemm_fp8_fp8_sm120: all inputs must be uint8 (FP8 E4M3)"); + } + if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) { + throw std::runtime_error("gemm_fp8_fp8_sm120: all inputs must be 2D"); + } + + int M = A.shape()[0]; + int K = A.shape()[1]; + int N = B.shape()[1]; + + // B is expected to be in ColumnMajor format [K, N] stored as [N, K] transposed + if (B.shape()[0] != static_cast(K)) { + throw std::runtime_error("gemm_fp8_fp8_sm120: A.shape[1] must equal B.shape[0]"); + } + if (D.shape()[0] != 
static_cast(M) || D.shape()[1] != static_cast(N)) { + throw std::runtime_error("gemm_fp8_fp8_sm120: D shape mismatch"); + } + + cudaError_t err = pygpukit_gemm_fp8_fp8_sm120( + static_cast(A.data()), + static_cast(B.data()), + static_cast(D.data()), + M, N, K, + 1.0f, 0.0f, + nullptr + ); + + if (err != cudaSuccess) { + throw std::runtime_error("gemm_fp8_fp8_sm120 failed: " + std::string(cudaGetErrorString(err))); + } + }, py::arg("A"), py::arg("B"), py::arg("D"), + "Pure FP8 I/O GEMM for SM120: D = A @ B (FP8 E4M3 input/output)"); + + // Blockwise scaled FP8 GEMM + m.def("gemm_fp8_fp8_blockwise_sm120", []( + const GPUArray& A, const GPUArray& B, GPUArray& D, + const GPUArray& scale_A, const GPUArray& scale_B + ) { + // FP8 is stored as UInt8 in GPUArray + if (A.dtype() != DataType::UInt8 || B.dtype() != DataType::UInt8 || D.dtype() != DataType::UInt8) { + throw std::runtime_error("gemm_fp8_fp8_blockwise_sm120: A, B, D must be uint8 (FP8 E4M3)"); + } + if (scale_A.dtype() != DataType::Float32 || scale_B.dtype() != DataType::Float32) { + throw std::runtime_error("gemm_fp8_fp8_blockwise_sm120: scale_A, scale_B must be float32"); + } + if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) { + throw std::runtime_error("gemm_fp8_fp8_blockwise_sm120: A, B, D must be 2D"); + } + + int M = A.shape()[0]; + int K = A.shape()[1]; + int N = B.shape()[1]; + + if (B.shape()[0] != static_cast(K)) { + throw std::runtime_error("gemm_fp8_fp8_blockwise_sm120: A.shape[1] must equal B.shape[0]"); + } + if (D.shape()[0] != static_cast(M) || D.shape()[1] != static_cast(N)) { + throw std::runtime_error("gemm_fp8_fp8_blockwise_sm120: D shape mismatch"); + } + + cudaError_t err = pygpukit_gemm_fp8_fp8_blockwise_sm120( + static_cast(A.data()), + static_cast(B.data()), + static_cast(D.data()), + static_cast(scale_A.data()), + static_cast(scale_B.data()), + M, N, K, + 1.0f, 0.0f, + nullptr + ); + + if (err != cudaSuccess) { + throw std::runtime_error("gemm_fp8_fp8_blockwise_sm120 failed: " + std::string(cudaGetErrorString(err))); + } + }, py::arg("A"), py::arg("B"), py::arg("D"), py::arg("scale_A"), py::arg("scale_B"), + "Blockwise scaled FP8 I/O GEMM for SM120: D = (A * scale_A) @ (B * scale_B)"); + + // Get scale factor sizes for FP8 blockwise GEMM + m.def("fp8_fp8_get_scale_sizes", [](int M, int N, int K) { + size_t sfa_size, sfb_size; + pygpukit_fp8_fp8_get_scale_sizes(M, N, K, &sfa_size, &sfb_size); + return py::make_tuple(sfa_size, sfb_size); + }, py::arg("M"), py::arg("N"), py::arg("K"), + "Get scale factor sizes for FP8 blockwise GEMM (returns (sfa_size, sfb_size))"); + + // ======================================================================== + // NVF4 (4-bit) GEMM for SM120 with BF16 I/O + // ======================================================================== + + m.def("nvf4_bf16_sm120_available", []() { + return pygpukit_nvf4_bf16_sm120_available(); + }, "Check if NVF4 BF16 GEMM is available on SM120"); + + m.def("gemm_nvf4_bf16_sm120", [](const GPUArray& A, const GPUArray& B, GPUArray& D) { + if (A.dtype() != DataType::BFloat16 || B.dtype() != DataType::BFloat16 || D.dtype() != DataType::BFloat16) { + throw std::runtime_error("gemm_nvf4_bf16_sm120: all inputs must be bfloat16"); + } + if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) { + throw std::runtime_error("gemm_nvf4_bf16_sm120: all inputs must be 2D"); + } + + int M = A.shape()[0]; + int K = A.shape()[1]; + int N = B.shape()[1]; + + if (B.shape()[0] != static_cast(K)) { + throw std::runtime_error("gemm_nvf4_bf16_sm120: A.shape[1] must 
equal B.shape[0]"); + } + if (D.shape()[0] != static_cast(M) || D.shape()[1] != static_cast(N)) { + throw std::runtime_error("gemm_nvf4_bf16_sm120: D shape mismatch"); + } + + cudaError_t err = pygpukit_gemm_nvf4_bf16_sm120( + static_cast(A.data()), + static_cast(B.data()), + static_cast<__nv_bfloat16*>(D.data()), + M, N, K, + 1.0f, 0.0f, + nullptr + ); + + if (err != cudaSuccess) { + throw std::runtime_error("gemm_nvf4_bf16_sm120 failed: " + std::string(cudaGetErrorString(err))); + } + }, py::arg("A"), py::arg("B"), py::arg("D"), + "NVF4 (4-bit) GEMM for SM120 with BF16 I/O: D = A @ B (BF16 -> NVF4 quantize -> GEMM -> BF16)"); + + m.def("nvf4_nvf4_sm120_available", []() { + return pygpukit_nvf4_nvf4_sm120_available(); + }, "Check if pure NVF4 GEMM is available (SM120+)"); + + m.def("benchmark_gemm_nvf4_sm120", [](GPUArray& D, int M, int N, int K) { + if (D.dtype() != DataType::BFloat16) { + throw std::runtime_error("benchmark_gemm_nvf4_sm120: D must be bfloat16"); + } + if (D.ndim() != 2) { + throw std::runtime_error("benchmark_gemm_nvf4_sm120: D must be 2D"); + } + + cudaError_t err = pygpukit_benchmark_gemm_nvf4_sm120( + static_cast<__nv_bfloat16*>(D.data()), + M, N, K, + 1.0f, 0.0f, + nullptr + ); + + if (err != cudaSuccess) { + throw std::runtime_error("benchmark_gemm_nvf4_sm120 failed: " + std::string(cudaGetErrorString(err))); + } + }, py::arg("D"), py::arg("M"), py::arg("N"), py::arg("K"), + "Benchmark pure NVF4 GEMM (pre-allocated data, no quantization overhead)"); + + // ======================================================================== + // NVF4 GEMV for SM120 (M=1 path) + // ======================================================================== + + m.def("gemv_nvf4_available", []() { + return pygpukit_gemv_nvf4_available(); + }, "Check if NVF4 GEMV is available (SM120+)"); + + m.def("quantize_bf16_to_nvf4", [](const GPUArray& input, GPUArray& out_data, GPUArray& out_scale) { + if (input.dtype() != DataType::BFloat16) { + throw std::runtime_error("quantize_bf16_to_nvf4: input must be bfloat16"); + } + if (input.ndim() != 2) { + throw std::runtime_error("quantize_bf16_to_nvf4: input must be 2D [K, N]"); + } + + int K = input.shape()[0]; + int N = input.shape()[1]; + + cudaError_t err = pygpukit_quantize_bf16_to_nvf4( + input.data(), out_data.data(), out_scale.data(), + K, N, nullptr + ); + + if (err != cudaSuccess) { + throw std::runtime_error("quantize_bf16_to_nvf4 failed: " + std::string(cudaGetErrorString(err))); + } + }, py::arg("input"), py::arg("out_data"), py::arg("out_scale"), + "Quantize BF16 weights to NVF4 format for SM120 GEMV"); + + m.def("gemv_nvf4_bf16", [](const GPUArray& A, const GPUArray& B_data, const GPUArray& B_scale, GPUArray& C, float alpha) { + if (A.dtype() != DataType::BFloat16 || C.dtype() != DataType::BFloat16) { + throw std::runtime_error("gemv_nvf4_bf16: A and C must be bfloat16"); + } + if (A.ndim() != 1) { + throw std::runtime_error("gemv_nvf4_bf16: A must be 1D [K]"); + } + + int K = A.shape()[0]; + int N = C.shape()[0]; + + cudaError_t err = pygpukit_gemv_nvf4_bf16( + A.data(), B_data.data(), B_scale.data(), C.data(), + K, N, alpha, nullptr + ); + + if (err != cudaSuccess) { + throw std::runtime_error("gemv_nvf4_bf16 failed: " + std::string(cudaGetErrorString(err))); + } + }, py::arg("A"), py::arg("B_data"), py::arg("B_scale"), py::arg("C"), py::arg("alpha") = 1.0f, + "NVF4 GEMV for SM120: C[N] = alpha * A[K] @ B[K,N] (NVF4 quantized weights)"); + + m.def("gemv_bf16", [](const GPUArray& A, const GPUArray& B, GPUArray& C, float alpha, 
float beta) { + if (A.dtype() != DataType::BFloat16 || B.dtype() != DataType::BFloat16 || C.dtype() != DataType::BFloat16) { + throw std::runtime_error("gemv_bf16: all inputs must be bfloat16"); + } + if (A.ndim() != 1 || B.ndim() != 2 || C.ndim() != 1) { + throw std::runtime_error("gemv_bf16: A[K], B[K,N], C[N] dimensions required"); + } + + int K = A.shape()[0]; + int N = B.shape()[1]; + + if (B.shape()[0] != static_cast(K)) { + throw std::runtime_error("gemv_bf16: K dimension mismatch"); + } + if (C.shape()[0] != static_cast(N)) { + throw std::runtime_error("gemv_bf16: N dimension mismatch"); + } + + cudaError_t err = pygpukit_gemv_bf16( + A.data(), B.data(), C.data(), + K, N, alpha, beta, nullptr + ); + + if (err != cudaSuccess) { + throw std::runtime_error("gemv_bf16 failed: " + std::string(cudaGetErrorString(err))); + } + }, py::arg("A"), py::arg("B"), py::arg("C"), py::arg("alpha") = 1.0f, py::arg("beta") = 0.0f, + "BF16 GEMV: C[N] = alpha * A[K] @ B[K,N] + beta * C[N]"); + + m.def("nvf4_get_sizes", [](int K, int N) { + size_t data_size, scale_size; + pygpukit_nvf4_get_sizes(K, N, &data_size, &scale_size); + return py::make_tuple(data_size, scale_size); + }, py::arg("K"), py::arg("N"), + "Get buffer sizes for NVF4 quantization: returns (data_size, scale_size)"); + + // ======================================================================== + // FP8 GEMM auto-dispatch (selects best available backend) + // Priority: SM120 (if enabled) > SM90 > error + // ======================================================================== + + m.def("fp8_available", []() { + // Check all FP8 backends: SM120 (disabled), SM100, SM90 + return pygpukit_fp8_sm120_available() || + pygpukit_fp8_sm100_available() || + pygpukit_fp8_sm90_available(); + }, "Check if FP8 GEMM is available (any backend)"); + + m.def("gemm_fp8", [](const GPUArray& A, const GPUArray& B, GPUArray& D) { + if (A.dtype() != DataType::Float32 || B.dtype() != DataType::Float32 || D.dtype() != DataType::Float32) { + throw std::runtime_error("gemm_fp8: all inputs must be float32"); + } + if (A.ndim() != 2 || B.ndim() != 2 || D.ndim() != 2) { + throw std::runtime_error("gemm_fp8: all inputs must be 2D"); + } + + int M = A.shape()[0]; + int K = A.shape()[1]; + int N = B.shape()[1]; + + if (B.shape()[0] != static_cast(K)) { + throw std::runtime_error("gemm_fp8: A.shape[1] must equal B.shape[0]"); + } + if (D.shape()[0] != static_cast(M) || D.shape()[1] != static_cast(N)) { + throw std::runtime_error("gemm_fp8: D shape mismatch"); + } + + cudaError_t err; + + // Try SM120 first (when CUTLASS bug is fixed, this will be preferred) + if (pygpukit_fp8_sm120_available()) { + err = pygpukit_gemm_fp8_sm120( + static_cast(A.data()), + static_cast(B.data()), + static_cast(D.data()), + M, N, K, 1.0f, 0.0f, nullptr + ); + if (err == cudaSuccess) return; + // Fall through to SM100 if SM120 fails + } + + // Try SM100 (Blackwell datacenter - potential fallback for SM120) + if (pygpukit_fp8_sm100_available()) { + err = pygpukit_gemm_fp8_sm100( + static_cast(A.data()), + static_cast(B.data()), + static_cast(D.data()), + M, N, K, 1.0f, 0.0f, nullptr + ); + if (err == cudaSuccess) return; + // Fall through to SM90 if SM100 fails + } + + // Try SM90 (Hopper) + if (pygpukit_fp8_sm90_available()) { + err = pygpukit_gemm_fp8_sm90( + static_cast(A.data()), + static_cast(B.data()), + static_cast(D.data()), + M, N, K, 1.0f, 0.0f, nullptr + ); + if (err != cudaSuccess) { + throw std::runtime_error("gemm_fp8 (SM90) failed: " + std::string(cudaGetErrorString(err))); + 
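To show how the NVF4 GEMV pieces fit together at decode time, here is a hedged host-side sketch built only on the extern "C" declarations at the top of this file. Buffer sizes come from `pygpukit_nvf4_get_sizes`; allocation lifetime and error handling are abbreviated, and the weight layout is taken to be BF16 `[K, N]` as in the quantize binding.

```cpp
// Sketch: SM120 NVF4 GEMV path via the extern "C" entry points declared above.
#include <cuda_runtime.h>
#include <cuda_bf16.h>

extern "C" {
bool pygpukit_gemv_nvf4_available();
void pygpukit_nvf4_get_sizes(int K, int N, size_t* data_size, size_t* scale_size);
cudaError_t pygpukit_quantize_bf16_to_nvf4(const void* input, void* out_data,
                                           void* out_scale, int K, int N,
                                           cudaStream_t stream);
cudaError_t pygpukit_gemv_nvf4_bf16(const void* A, const void* B_data,
                                    const void* B_scale, void* C,
                                    int K, int N, float alpha,
                                    cudaStream_t stream);
}

// y[N] = x[K] @ W[K,N], with W kept in NVF4 form between calls.
void nvf4_gemv_example(const __nv_bfloat16* d_W, const __nv_bfloat16* d_x,
                       __nv_bfloat16* d_y, int K, int N, cudaStream_t stream) {
    if (!pygpukit_gemv_nvf4_available()) return;  // requires SM120+

    size_t data_size = 0, scale_size = 0;
    pygpukit_nvf4_get_sizes(K, N, &data_size, &scale_size);

    void *d_Wq = nullptr, *d_Wscale = nullptr;
    cudaMalloc(&d_Wq, data_size);
    cudaMalloc(&d_Wscale, scale_size);

    // One-time GPU-side quantization of the BF16 weights to NVF4.
    pygpukit_quantize_bf16_to_nvf4(d_W, d_Wq, d_Wscale, K, N, stream);

    // Per-token GEMV against the quantized weights.
    pygpukit_gemv_nvf4_bf16(d_x, d_Wq, d_Wscale, d_y, K, N, 1.0f, stream);

    cudaFree(d_Wq);
    cudaFree(d_Wscale);
}
```

In an inference loop the quantization step would run once at load time and only the final `pygpukit_gemv_nvf4_bf16` call would sit on the decode path.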
} + return; + } + + throw std::runtime_error("gemm_fp8: no FP8 backend available (requires SM90+)"); + }, py::arg("A"), py::arg("B"), py::arg("D"), + "FP8 GEMM with auto backend selection: D = A @ B"); } diff --git a/native/jit/cublaslt_loader.cpp b/native/jit/cublaslt_loader.cpp index 5045097..51c355c 100644 --- a/native/jit/cublaslt_loader.cpp +++ b/native/jit/cublaslt_loader.cpp @@ -54,6 +54,7 @@ using PFN_cublasLtMatmulDescDestroy = cublasStatus_t (CUBLASAPI *)(cublasLtMatmu using PFN_cublasLtMatmulDescSetAttribute = cublasStatus_t (CUBLASAPI *)(cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, const void*, size_t); using PFN_cublasLtMatrixLayoutCreate = cublasStatus_t (CUBLASAPI *)(cublasLtMatrixLayout_t*, int, uint64_t, uint64_t, int64_t); using PFN_cublasLtMatrixLayoutDestroy = cublasStatus_t (CUBLASAPI *)(cublasLtMatrixLayout_t); +using PFN_cublasLtMatrixLayoutSetAttribute = cublasStatus_t (CUBLASAPI *)(cublasLtMatrixLayout_t, cublasLtMatrixLayoutAttribute_t, const void*, size_t); using PFN_cublasLtMatmul = cublasStatus_t (CUBLASAPI *)( cublasLtHandle_t, cublasLtMatmulDesc_t, const void*, const void*, cublasLtMatrixLayout_t, @@ -98,6 +99,7 @@ struct CublasLtState { PFN_cublasLtMatmulDescSetAttribute pfn_matmul_desc_set_attr{nullptr}; PFN_cublasLtMatrixLayoutCreate pfn_matrix_layout_create{nullptr}; PFN_cublasLtMatrixLayoutDestroy pfn_matrix_layout_destroy{nullptr}; + PFN_cublasLtMatrixLayoutSetAttribute pfn_matrix_layout_set_attr{nullptr}; PFN_cublasLtMatmul pfn_matmul{nullptr}; // Preference and heuristic function pointers (for CUDA Graph compatibility) @@ -109,22 +111,52 @@ struct CublasLtState { CublasLtState g_state; +// Get CUDA runtime major version +int get_cuda_major_version() { + int version = 0; + cudaError_t err = cudaRuntimeGetVersion(&version); + if (err != cudaSuccess) { + return 12; // Default to 12 if query fails + } + // version is encoded as major * 1000 + minor * 10 + return version / 1000; +} + // Search for cuBLASLt library in various locations std::vector get_search_paths() { std::vector paths; + // Get CUDA runtime version to match cuBLASLt version + int cuda_major = get_cuda_major_version(); + fprintf(stderr, "[cuBLASLt] CUDA runtime major version: %d\n", cuda_major); + #ifdef _WIN32 // Windows: Search for cublasLt64_*.dll - // Note: CUDA 13.x puts DLLs in bin/x64/ subdirectory + // Prioritize paths matching the CUDA runtime version + + if (cuda_major >= 13) { + // CUDA 13.x: bin/x64 subdirectory + paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v13.1\\bin\\x64"); + paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v13.0\\bin\\x64"); + } else { + // CUDA 12.x: bin directly + paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.9\\bin"); + paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.8\\bin"); + paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.6\\bin"); + paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.5\\bin"); + paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\bin"); + } - // 1. 
Check CUDA_PATH environment variable + // Then check CUDA_PATH as fallback const char* cuda_path = std::getenv("CUDA_PATH"); if (cuda_path) { - paths.push_back(std::string(cuda_path) + "\\bin\\x64"); // CUDA 13.x - paths.push_back(std::string(cuda_path) + "\\bin"); // CUDA 12.x and earlier + if (cuda_major >= 13) { + paths.push_back(std::string(cuda_path) + "\\bin\\x64"); + } + paths.push_back(std::string(cuda_path) + "\\bin"); } - // 2. Check PATH directories + // Check PATH directories as last resort const char* path_env = std::getenv("PATH"); if (path_env) { std::string path_str(path_env); @@ -139,21 +171,6 @@ std::vector get_search_paths() { } } - // 3. Common installation paths (CUDA 13.x uses bin/x64) - paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v13.1\\bin\\x64"); - paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v13.0\\bin\\x64"); - // CUDA 12.x uses bin directly - paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.9\\bin"); - paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.8\\bin"); - paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.6\\bin"); - paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.5\\bin"); - paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\bin"); - paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.3\\bin"); - paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.2\\bin"); - paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.1\\bin"); - paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.0\\bin"); - paths.push_back("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.8\\bin"); - #else // Linux/macOS: Search for libcublasLt.so @@ -191,7 +208,14 @@ std::vector get_search_paths() { #ifdef _WIN32 // Find cuBLASLt DLL in a directory (Windows) -std::string find_cublaslt_in_dir(const std::string& dir) { +// Prefers the version matching cuda_major +std::string find_cublaslt_in_dir(const std::string& dir, int cuda_major) { + // First, try the exact version matching the CUDA runtime + std::string preferred_path = dir + "\\cublasLt64_" + std::to_string(cuda_major) + ".dll"; + if (GetFileAttributesA(preferred_path.c_str()) != INVALID_FILE_ATTRIBUTES) { + return preferred_path; + } + // Search for cublasLt64_*.dll pattern (e.g., cublasLt64_12.dll, cublasLt64_13.dll) WIN32_FIND_DATAA find_data; std::string pattern = dir + "\\cublasLt64_*.dll"; @@ -209,14 +233,6 @@ std::string find_cublaslt_in_dir(const std::string& dir) { return exact_path; } - // Try specific version patterns for CUDA 13.x - for (int ver = 13; ver >= 11; --ver) { - std::string versioned_path = dir + "\\cublasLt64_" + std::to_string(ver) + ".dll"; - if (GetFileAttributesA(versioned_path.c_str()) != INVALID_FILE_ATTRIBUTES) { - return versioned_path; - } - } - return ""; } #else @@ -274,6 +290,7 @@ bool try_load(const std::string& path) { auto pfn_matmul_desc_set_attr = (PFN_cublasLtMatmulDescSetAttribute)GET_PROC(handle, "cublasLtMatmulDescSetAttribute"); auto pfn_matrix_layout_create = (PFN_cublasLtMatrixLayoutCreate)GET_PROC(handle, "cublasLtMatrixLayoutCreate"); auto pfn_matrix_layout_destroy = (PFN_cublasLtMatrixLayoutDestroy)GET_PROC(handle, "cublasLtMatrixLayoutDestroy"); + auto pfn_matrix_layout_set_attr = (PFN_cublasLtMatrixLayoutSetAttribute)GET_PROC(handle, "cublasLtMatrixLayoutSetAttribute"); auto pfn_matmul = 
(PFN_cublasLtMatmul)GET_PROC(handle, "cublasLtMatmul"); // Preference and heuristic functions (for CUDA Graph compatibility) @@ -285,7 +302,8 @@ bool try_load(const std::string& path) { // All core functions must be present if (!pfn_create || !pfn_destroy || !pfn_matmul_desc_create || !pfn_matmul_desc_destroy || !pfn_matmul_desc_set_attr || - !pfn_matrix_layout_create || !pfn_matrix_layout_destroy || !pfn_matmul) { + !pfn_matrix_layout_create || !pfn_matrix_layout_destroy || + !pfn_matrix_layout_set_attr || !pfn_matmul) { FREE_LIBRARY(handle); return false; } @@ -314,6 +332,7 @@ bool try_load(const std::string& path) { g_state.pfn_matmul_desc_set_attr = pfn_matmul_desc_set_attr; g_state.pfn_matrix_layout_create = pfn_matrix_layout_create; g_state.pfn_matrix_layout_destroy = pfn_matrix_layout_destroy; + g_state.pfn_matrix_layout_set_attr = pfn_matrix_layout_set_attr; g_state.pfn_matmul = pfn_matmul; // Preference and heuristic function pointers @@ -343,9 +362,14 @@ bool initialize() { // Search for cuBLASLt auto search_paths = get_search_paths(); + int cuda_major = get_cuda_major_version(); for (const auto& dir : search_paths) { +#ifdef _WIN32 + std::string cublaslt_path = find_cublaslt_in_dir(dir, cuda_major); +#else std::string cublaslt_path = find_cublaslt_in_dir(dir); +#endif if (!cublaslt_path.empty() && try_load(cublaslt_path)) { g_state.available.store(true, std::memory_order_relaxed); g_state.initialized.store(true, std::memory_order_release); @@ -367,6 +391,27 @@ bool is_available() { } // First call: do full initialization initialize(); + + // SM 120 (Blackwell GeForce) has cuBLASLt compatibility issues + // AlgoGetHeuristic returns NOT_SUPPORTED (status=15) for most operations + // Disable cuBLASLt on SM >= 120 unless PYGPUKIT_CUBLASLT_SM120=1 + if (g_state.available.load(std::memory_order_relaxed)) { + int device_id = 0; + cudaGetDevice(&device_id); + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device_id); + int sm_version = props.major * 10 + props.minor; + if (sm_version >= 120) { + const char* force_sm120 = std::getenv("PYGPUKIT_CUBLASLT_SM120"); + if (force_sm120 && std::string(force_sm120) == "1") { + fprintf(stderr, "[cuBLASLt] Force-enabled on SM %d (PYGPUKIT_CUBLASLT_SM120=1)\n", sm_version); + } else { + fprintf(stderr, "[cuBLASLt] Disabled on SM %d (set PYGPUKIT_CUBLASLT_SM120=1 to force)\n", sm_version); + g_state.available.store(false, std::memory_order_relaxed); + } + } + } + return g_state.available.load(std::memory_order_relaxed); } @@ -438,6 +483,16 @@ cublasStatus_t matrix_layout_destroy(cublasLtMatrixLayout_t matLayout) { return g_state.pfn_matrix_layout_destroy(matLayout); } +cublasStatus_t matrix_layout_set_attribute( + cublasLtMatrixLayout_t matLayout, + cublasLtMatrixLayoutAttribute_t attr, + const void* buf, + size_t sizeInBytes +) { + if (!is_available()) return CUBLAS_STATUS_NOT_INITIALIZED; + return g_state.pfn_matrix_layout_set_attr(matLayout, attr, buf, sizeInBytes); +} + cublasStatus_t matmul( cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc, @@ -470,6 +525,7 @@ cublasStatus_t matmul( cublasLtHandle_t get_handle() { if (!is_available()) { + fprintf(stderr, "[cuBLASLt] get_handle: not available\n"); return nullptr; } @@ -485,10 +541,33 @@ cublasLtHandle_t get_handle() { return g_state.lt_handle; } + // Ensure CUDA is initialized before creating cuBLASLt handle + int device = -1; + cudaError_t cuda_err = cudaGetDevice(&device); + fprintf(stderr, "[cuBLASLt] cudaGetDevice returned: %d, device=%d\n", static_cast(cuda_err), 
device); + if (cuda_err != cudaSuccess || device < 0) { + // Force CUDA initialization + fprintf(stderr, "[cuBLASLt] Calling cudaSetDevice(0)...\n"); + cuda_err = cudaSetDevice(0); + if (cuda_err != cudaSuccess) { + fprintf(stderr, "[cuBLASLt] ERROR: Failed to initialize CUDA: %d\n", static_cast(cuda_err)); + return nullptr; + } + // Try to get device again + cudaGetDevice(&device); + fprintf(stderr, "[cuBLASLt] After cudaSetDevice, device=%d\n", device); + } + + // Sync device to ensure context is ready + cudaDeviceSynchronize(); + cublasLtHandle_t handle = nullptr; cublasStatus_t status = g_state.pfn_create(&handle); + fprintf(stderr, "[cuBLASLt] cublasLtCreate returned: %d, handle=%p\n", static_cast(status), handle); if (status == CUBLAS_STATUS_SUCCESS) { g_state.lt_handle = handle; + } else { + fprintf(stderr, "[cuBLASLt] ERROR: Failed to create cuBLASLt handle!\n"); } return g_state.lt_handle; @@ -824,5 +903,178 @@ cudaError_t gemm_bf16( return cudaSuccess; } +cudaError_t gemm_strided_batched_fp32( + const float* A, const float* B, float* C, + int M, int N, int K, int batch_count, + int64_t strideA, int64_t strideB, int64_t strideC, + cudaStream_t stream +) { + fprintf(stderr, "[cuBLASLt] gemm_strided_batched_fp32: M=%d N=%d K=%d batch=%d strideA=%lld strideB=%lld strideC=%lld\n", + M, N, K, batch_count, (long long)strideA, (long long)strideB, (long long)strideC); + + g_last_cublaslt_error = 0; + g_last_cublaslt_step = 0; + + cublasLtHandle_t handle = get_handle(); + if (!handle) { + g_last_cublaslt_step = 1; + g_last_cublaslt_error = -1; + return cudaErrorNotReady; + } + + cublasStatus_t status; + + // Create matmul descriptor + cublasLtMatmulDesc_t operationDesc = nullptr; + status = matmul_desc_create(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); + if (status != CUBLAS_STATUS_SUCCESS) { + g_last_cublaslt_step = 2; + g_last_cublaslt_error = static_cast(status); + return cudaErrorUnknown; + } + + // Set transpose attributes (NN for row-major: C = A @ B) + // cuBLASLt is column-major, so we compute C^T = B^T @ A^T + cublasOperation_t transA = CUBLAS_OP_N; + cublasOperation_t transB = CUBLAS_OP_N; + matmul_desc_set_attribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transA, sizeof(transA)); + matmul_desc_set_attribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transB, sizeof(transB)); + + // Create matrix layouts with batch info (swapped for row-major) + // Row-major C[M,N] = A[M,K] @ B[K,N] + // Column-major: C^T[N,M] = B^T[N,K] @ A^T[K,M] + cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr; + + // B^T layout: [N, K] with ld=N, stride between batches + fprintf(stderr, "[cuBLASLt] Creating Bdesc: rows=%d cols=%d ld=%d\n", N, K, N); + status = matrix_layout_create(&Bdesc, CUDA_R_32F, N, K, N); + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, "[cuBLASLt] Bdesc creation failed: %d\n", static_cast(status)); + g_last_cublaslt_step = 3; + g_last_cublaslt_error = static_cast(status); + matmul_desc_destroy(operationDesc); + return cudaErrorUnknown; + } + status = matrix_layout_set_attribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)); + fprintf(stderr, "[cuBLASLt] Bdesc batch_count set: %d\n", static_cast(status)); + status = matrix_layout_set_attribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideB, sizeof(strideB)); + fprintf(stderr, "[cuBLASLt] Bdesc stride set: %d\n", static_cast(status)); + + // A^T layout: [K, M] with ld=K, stride between batches + fprintf(stderr, "[cuBLASLt] Creating Adesc: rows=%d 
cols=%d ld=%d\n", K, M, K); + status = matrix_layout_create(&Adesc, CUDA_R_32F, K, M, K); + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, "[cuBLASLt] Adesc creation failed: %d\n", static_cast(status)); + g_last_cublaslt_step = 4; + g_last_cublaslt_error = static_cast(status); + matrix_layout_destroy(Bdesc); + matmul_desc_destroy(operationDesc); + return cudaErrorUnknown; + } + status = matrix_layout_set_attribute(Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)); + fprintf(stderr, "[cuBLASLt] Adesc batch_count set: %d\n", static_cast(status)); + status = matrix_layout_set_attribute(Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideA, sizeof(strideA)); + fprintf(stderr, "[cuBLASLt] Adesc stride set: %d\n", static_cast(status)); + + // C^T layout: [N, M] with ld=N, stride between batches + fprintf(stderr, "[cuBLASLt] Creating Cdesc: rows=%d cols=%d ld=%d\n", N, M, N); + status = matrix_layout_create(&Cdesc, CUDA_R_32F, N, M, N); + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, "[cuBLASLt] Cdesc creation failed: %d\n", static_cast(status)); + g_last_cublaslt_step = 5; + g_last_cublaslt_error = static_cast(status); + matrix_layout_destroy(Adesc); + matrix_layout_destroy(Bdesc); + matmul_desc_destroy(operationDesc); + return cudaErrorUnknown; + } + status = matrix_layout_set_attribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)); + fprintf(stderr, "[cuBLASLt] Cdesc batch_count set: %d\n", static_cast(status)); + status = matrix_layout_set_attribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideC, sizeof(strideC)); + fprintf(stderr, "[cuBLASLt] Cdesc stride set: %d\n", static_cast(status)); + + float alpha = 1.0f; + float beta = 0.0f; + + // Select algorithm for batched GEMM using heuristics + cublasLtMatmulAlgo_t algo; + bool has_algo = false; + void* workspace = nullptr; + size_t workspaceSize = 0; + + if (g_state.pfn_pref_create && g_state.pfn_algo_get_heuristic) { + cublasLtMatmulPreference_t preference = nullptr; + status = g_state.pfn_pref_create(&preference); + if (status == CUBLAS_STATUS_SUCCESS && preference) { + constexpr size_t MAX_WORKSPACE = 32 * 1024 * 1024; + g_state.pfn_pref_set_attr(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &MAX_WORKSPACE, sizeof(MAX_WORKSPACE)); + + cublasLtMatmulHeuristicResult_struct heuristicResult; + int returnedResults = 0; + + status = g_state.pfn_algo_get_heuristic( + handle, operationDesc, + Bdesc, Adesc, // Swapped for row-major + Cdesc, Cdesc, + preference, 1, &heuristicResult, &returnedResults + ); + + fprintf(stderr, "[cuBLASLt] Batched AlgoGetHeuristic: status=%d, results=%d\n", + static_cast(status), returnedResults); + + if (status == CUBLAS_STATUS_SUCCESS && returnedResults > 0) { + algo = heuristicResult.algo; + workspaceSize = heuristicResult.workspaceSize; + has_algo = true; + + if (workspaceSize > 0) { + CUdeviceptr dptr = 0; + CUresult err = cuMemAlloc(&dptr, workspaceSize); + if (err == CUDA_SUCCESS) { + workspace = reinterpret_cast(dptr); + } + } + } + + g_state.pfn_pref_destroy(preference); + } + } + + // Execute batched matmul + fprintf(stderr, "[cuBLASLt] Calling cublasLtMatmul (has_algo=%d, ws=%zu)...\n", has_algo, workspaceSize); + status = g_state.pfn_matmul( + handle, operationDesc, + &alpha, + B, Bdesc, + A, Adesc, + &beta, + C, Cdesc, + C, Cdesc, + has_algo ? 
&algo : nullptr, + workspace, workspaceSize, stream + ); + fprintf(stderr, "[cuBLASLt] cublasLtMatmul returned: %d\n", static_cast(status)); + + // Free workspace if allocated + if (workspace) { + cuMemFree(reinterpret_cast(workspace)); + } + + // Cleanup + matrix_layout_destroy(Cdesc); + matrix_layout_destroy(Adesc); + matrix_layout_destroy(Bdesc); + matmul_desc_destroy(operationDesc); + + if (status != CUBLAS_STATUS_SUCCESS) { + g_last_cublaslt_step = 6; + g_last_cublaslt_error = static_cast(status); + return cudaErrorUnknown; + } + + return cudaSuccess; +} + } // namespace cublaslt } // namespace pygpukit diff --git a/native/jit/cublaslt_loader.hpp b/native/jit/cublaslt_loader.hpp index bd66324..530783a 100644 --- a/native/jit/cublaslt_loader.hpp +++ b/native/jit/cublaslt_loader.hpp @@ -71,6 +71,19 @@ enum cublasLtMatmulPreferenceAttributes_t { CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES = 1 }; +// Matrix layout attributes for batched GEMM +enum cublasLtMatrixLayoutAttribute_t { + CUBLASLT_MATRIX_LAYOUT_ORDER = 1, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT = 5, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET = 6 +}; + +// Matrix order +enum cublasLtOrder_t { + CUBLASLT_ORDER_COL = 0, + CUBLASLT_ORDER_ROW = 1 +}; + // Algorithm structure (64 bytes as per cuBLAS documentation) struct cublasLtMatmulAlgo_t { uint64_t data[8]; @@ -130,6 +143,13 @@ cublasStatus_t matrix_layout_create( cublasStatus_t matrix_layout_destroy(cublasLtMatrixLayout_t matLayout); +cublasStatus_t matrix_layout_set_attribute( + cublasLtMatrixLayout_t matLayout, + cublasLtMatrixLayoutAttribute_t attr, + const void* buf, + size_t sizeInBytes +); + cublasStatus_t matmul( cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc, @@ -177,6 +197,15 @@ cudaError_t gemm_bf16( cudaStream_t stream = nullptr ); +// Strided Batched FP32 GEMM: C[b] = A[b] @ B[b] for b in [0, batch_count) +// A: [batch_count, M, K], B: [batch_count, K, N], C: [batch_count, M, N] +cudaError_t gemm_strided_batched_fp32( + const float* A, const float* B, float* C, + int M, int N, int K, int batch_count, + int64_t strideA, int64_t strideB, int64_t strideC, + cudaStream_t stream = nullptr +); + // Debug functions int get_last_cublaslt_error(); // Returns last cuBLASLt status code int get_last_cublaslt_step(); // Returns which step failed (1-6) diff --git a/native/ops/audio/audio.cu b/native/ops/audio/audio.cu index b82eae1..8753d0b 100644 --- a/native/ops/audio/audio.cu +++ b/native/ops/audio/audio.cu @@ -183,13 +183,16 @@ GPUArray resample(const GPUArray& input, int src_rate, int dst_rate) { throw std::runtime_error("resample: input must be Float32"); } - // Currently only support 48kHz -> 16kHz (3:1 decimation) - if (src_rate != 48000 || dst_rate != 16000) { - throw std::runtime_error("resample: currently only 48000 -> 16000 is supported"); + if (src_rate == dst_rate) { + // No resampling needed, return copy + GPUArray output(input.shape(), DataType::Float32); + cudaMemcpy(output.data(), input.data(), input.size() * sizeof(float), cudaMemcpyDeviceToDevice); + return output; } int in_len = static_cast(input.size()); - int out_len = in_len / 3; // 3:1 decimation + int out_len = static_cast(static_cast(in_len) * dst_rate / src_rate); + float ratio = static_cast(src_rate) / static_cast(dst_rate); GPUArray output({static_cast(out_len)}, DataType::Float32); @@ -198,13 +201,24 @@ GPUArray resample(const GPUArray& input, int src_rate, int dst_rate) { cudaStream_t stream = internal::get_capture_stream(); - resample_polyphase_kernel<<>>( - static_cast(input.data()), 
- static_cast(output.data()), - in_len, - out_len); + // Use optimized polyphase filter for 48kHz -> 16kHz + if (src_rate == 48000 && dst_rate == 16000) { + resample_polyphase_kernel<<>>( + static_cast(input.data()), + static_cast(output.data()), + in_len, + out_len); + } else { + // Generic linear interpolation for other sample rates + resample_linear_kernel<<>>( + static_cast(input.data()), + static_cast(output.data()), + in_len, + out_len, + ratio); + } - sync_and_check("resample_polyphase kernel failed"); + sync_and_check("resample kernel failed"); return output; } diff --git a/native/ops/audio/audio_kernels.cuh b/native/ops/audio/audio_kernels.cuh index d02a88c..aa186a4 100644 --- a/native/ops/audio/audio_kernels.cuh +++ b/native/ops/audio/audio_kernels.cuh @@ -178,6 +178,29 @@ __global__ void resample_polyphase_kernel( output[out_idx] = sum; } +// Generic linear interpolation resampler for arbitrary sample rates +__global__ void resample_linear_kernel( + const float* __restrict__ input, + float* __restrict__ output, + int in_len, + int out_len, + float ratio) // ratio = src_rate / dst_rate +{ + int out_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (out_idx >= out_len) return; + + // Map output sample to input position (floating point) + float in_pos = out_idx * ratio; + int in_idx = static_cast(in_pos); + float frac = in_pos - in_idx; + + // Linear interpolation between adjacent samples + float sample0 = (in_idx < in_len) ? input[in_idx] : 0.0f; + float sample1 = (in_idx + 1 < in_len) ? input[in_idx + 1] : sample0; + + output[out_idx] = sample0 + frac * (sample1 - sample0); +} + // ============================================================================ // Ring Buffer Operations (for streaming) // ============================================================================ @@ -1908,6 +1931,93 @@ __global__ void spectral_contrast_kernel( contrast[frame_idx * n_bands + band_idx] = logf(peak + 1e-10f) - logf(valley + 1e-10f); } +// ============================================================================ +// Conv1D - 1D convolution for audio/signal processing +// Input: [batch, in_channels, length] +// Kernel: [out_channels, in_channels, kernel_size] +// Output: [batch, out_channels, out_length] +// ============================================================================ + +__global__ void conv1d_f32_kernel( + const float* __restrict__ input, // [B, C_in, L] + const float* __restrict__ weight, // [C_out, C_in, K] + const float* __restrict__ bias, // [C_out] or nullptr + float* __restrict__ output, // [B, C_out, L_out] + int batch, int in_channels, int out_channels, + int in_length, int kernel_size, int stride, int padding +) { + int out_length = (in_length + 2 * padding - kernel_size) / stride + 1; + int total = batch * out_channels * out_length; + + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total) return; + + int b = idx / (out_channels * out_length); + int rem = idx % (out_channels * out_length); + int oc = rem / out_length; + int ol = rem % out_length; + + float sum = 0.0f; + int in_start = ol * stride - padding; + + for (int ic = 0; ic < in_channels; ++ic) { + for (int k = 0; k < kernel_size; ++k) { + int il = in_start + k; + if (il >= 0 && il < in_length) { + float in_val = input[b * in_channels * in_length + ic * in_length + il]; + float w_val = weight[oc * in_channels * kernel_size + ic * kernel_size + k]; + sum += in_val * w_val; + } + } + } + + if (bias != nullptr) { + sum += bias[oc]; + } + + output[b * out_channels * out_length + oc * 
out_length + ol] = sum; +} + +__global__ void conv1d_f16_kernel( + const __half* __restrict__ input, + const __half* __restrict__ weight, + const __half* __restrict__ bias, + __half* __restrict__ output, + int batch, int in_channels, int out_channels, + int in_length, int kernel_size, int stride, int padding +) { + int out_length = (in_length + 2 * padding - kernel_size) / stride + 1; + int total = batch * out_channels * out_length; + + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total) return; + + int b = idx / (out_channels * out_length); + int rem = idx % (out_channels * out_length); + int oc = rem / out_length; + int ol = rem % out_length; + + float sum = 0.0f; + int in_start = ol * stride - padding; + + for (int ic = 0; ic < in_channels; ++ic) { + for (int k = 0; k < kernel_size; ++k) { + int il = in_start + k; + if (il >= 0 && il < in_length) { + float in_val = __half2float(input[b * in_channels * in_length + ic * in_length + il]); + float w_val = __half2float(weight[oc * in_channels * kernel_size + ic * kernel_size + k]); + sum += in_val * w_val; + } + } + } + + if (bias != nullptr) { + sum += __half2float(bias[oc]); + } + + output[b * out_channels * out_length + oc * out_length + ol] = __float2half(sum); +} + } // namespace audio } // namespace ops } // namespace pygpukit diff --git a/native/ops/elementwise/elementwise.cu b/native/ops/elementwise/elementwise.cu index a9c6df7..e0750e4 100644 --- a/native/ops/elementwise/elementwise.cu +++ b/native/ops/elementwise/elementwise.cu @@ -262,5 +262,117 @@ GPUArray div(const GPUArray& a, const GPUArray& b) { return c; } +// ============================================================================ +// Clamp +// ============================================================================ + +void clamp(const GPUArray& a, GPUArray& c, float min_val, float max_val) { + validate_same_shape(a, c, "clamp"); + validate_same_dtype(a, c, "clamp"); + + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("clamp only supports float types"); + } + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + clamp_f32_kernel<<>>( + static_cast(a.data()), + static_cast(c.data()), + min_val, max_val, n); + break; + case DataType::Float16: + clamp_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(c.data()), + min_val, max_val, n); + break; + case DataType::BFloat16: + clamp_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(c.data()), + min_val, max_val, n); + break; + default: + break; + } + sync_and_check("clamp kernel failed"); +} + +GPUArray clamp(const GPUArray& a, float min_val, float max_val) { + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("clamp only supports float types"); + } + GPUArray c(a.shape(), a.dtype()); + clamp(a, c, min_val, max_val); + return c; +} + +// ============================================================================ +// Where (conditional select) +// ============================================================================ + +void where(const GPUArray& cond, const GPUArray& a, const GPUArray& b, GPUArray& c) { + validate_same_shape(a, b, "where"); + validate_same_shape(a, c, "where"); + validate_same_dtype(a, b, "where"); + validate_same_dtype(a, c, "where"); + + if (cond.size() != a.size()) 
{ + throw std::runtime_error("where: condition shape must match input shape"); + } + if (cond.dtype() != DataType::UInt8 && cond.dtype() != DataType::Int8) { + throw std::runtime_error("where: condition must be uint8 or int8 type (boolean)"); + } + + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("where only supports float types"); + } + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + where_f32_kernel<<>>( + static_cast(cond.data()), + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), n); + break; + case DataType::Float16: + where_f16_kernel<<>>( + static_cast(cond.data()), + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), n); + break; + case DataType::BFloat16: + where_bf16_kernel<<>>( + static_cast(cond.data()), + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), n); + break; + default: + break; + } + sync_and_check("where kernel failed"); +} + +GPUArray where(const GPUArray& cond, const GPUArray& a, const GPUArray& b) { + GPUArray c(a.shape(), a.dtype()); + where(cond, a, b, c); + return c; +} + } // namespace ops } // namespace pygpukit diff --git a/native/ops/elementwise/elementwise_kernels.cuh b/native/ops/elementwise/elementwise_kernels.cuh index 64dd689..d4220a8 100644 --- a/native/ops/elementwise/elementwise_kernels.cuh +++ b/native/ops/elementwise/elementwise_kernels.cuh @@ -197,6 +197,66 @@ __global__ void div_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, } } +// ============================================================================ +// Clamp/Clip kernels - clamp values to [min, max] range +// ============================================================================ + +__global__ void clamp_f32_kernel(const float* a, float* c, float min_val, float max_val, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = fminf(fmaxf(a[idx], min_val), max_val); + } +} + +__global__ void clamp_f16_kernel(const __half* a, __half* c, float min_val, float max_val, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float v = __half2float(a[idx]); + c[idx] = __float2half(fminf(fmaxf(v, min_val), max_val)); + } +} + +__global__ void clamp_bf16_kernel(const __nv_bfloat16* a, __nv_bfloat16* c, float min_val, float max_val, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float v = bf16_to_float(a[idx]); + c[idx] = float_to_bf16(fminf(fmaxf(v, min_val), max_val)); + } +} + +// ============================================================================ +// Where/Select kernels - conditional selection: out = cond ? a : b +// ============================================================================ + +__global__ void where_f32_kernel(const uint8_t* cond, const float* a, const float* b, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = cond[idx] ? a[idx] : b[idx]; + } +} + +__global__ void where_f16_kernel(const uint8_t* cond, const __half* a, const __half* b, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = cond[idx] ? 
a[idx] : b[idx]; + } +} + +__global__ void where_bf16_kernel(const uint8_t* cond, const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = cond[idx] ? a[idx] : b[idx]; + } +} + +// Scalar variants for where (useful for masking with constant) +__global__ void where_scalar_f32_kernel(const uint8_t* cond, const float* a, float b, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = cond[idx] ? a[idx] : b; + } +} + } // namespace elementwise } // namespace ops } // namespace pygpukit diff --git a/native/ops/gemv/benchmark_gemv.cu b/native/ops/gemv/benchmark_gemv.cu new file mode 100644 index 0000000..f4e5a06 --- /dev/null +++ b/native/ops/gemv/benchmark_gemv.cu @@ -0,0 +1,394 @@ +/** + * GEMV Benchmark: CUTLASS vs cuBLASLt + * + * Compares our CUTLASS-based GEMV with cuBLASLt GEMV under identical conditions. + * + * Build: + * nvcc -std=c++17 -O3 -arch=sm_86 benchmark_gemv.cu -lcublasLt -o benchmark_gemv + * + * Usage: + * ./benchmark_gemv [K] [N] + * Default: K=4096, N=4096 (typical LLM hidden size) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gemv_cutlass.cuh" + +// ============================================================================ +// Benchmark Configuration +// ============================================================================ + +constexpr int WARMUP_ITERATIONS = 20; +constexpr int BENCHMARK_ITERATIONS = 100; + +// Common LLM hidden sizes for benchmarking +struct BenchmarkCase { + int K; + int N; + const char* name; +}; + +const BenchmarkCase BENCHMARK_CASES[] = { + // Small models (< 1B params) + {768, 768, "768x768 (BERT-base)"}, + {1024, 1024, "1024x1024 (GPT-small)"}, + {2048, 2048, "2048x2048 (GPT-medium)"}, + + // Medium models (1-7B params) + {4096, 4096, "4096x4096 (LLaMA-7B hidden)"}, + {4096, 11008, "4096x11008 (LLaMA-7B MLP)"}, + {4096, 14336, "4096x14336 (Qwen-7B MLP)"}, + + // Large models (7-70B params) + {5120, 5120, "5120x5120 (LLaMA-13B)"}, + {8192, 8192, "8192x8192 (LLaMA-70B hidden)"}, + {8192, 28672, "8192x28672 (LLaMA-70B MLP)"}, + + // Extreme cases + {16384, 16384, "16384x16384 (large)"}, + {4096, 32768, "4096x32768 (wide)"}, + {32768, 4096, "32768x4096 (tall)"}, +}; + +// ============================================================================ +// cuBLASLt GEMV Wrapper +// ============================================================================ + +class CuBLASLtGemv { +public: + CuBLASLtGemv() { + cublasLtCreate(&handle_); + } + + ~CuBLASLtGemv() { + cublasLtDestroy(handle_); + } + + // BF16 GEMV using cuBLASLt + // C[1,N] = A[1,K] @ B[K,N] + cudaError_t gemv_bf16( + const __nv_bfloat16* A, // [1, K] + const __nv_bfloat16* B, // [K, N] + __nv_bfloat16* C, // [1, N] + int K, int N, + float alpha, float beta, + cudaStream_t stream + ) { + // cuBLASLt uses column-major, so we compute C^T = B^T @ A^T + // For row-major: C[1,N] = A[1,K] @ B[K,N] + // In col-major view: C^T[N,1] = B^T[N,K] @ A^T[K,1] + // + // However, for M=1, it's simpler to just call GEMM with M=1 + // cuBLASLt GEMM: D = alpha * A @ B + beta * C + // With m=1, n=N, k=K in column-major terms + + cublasLtMatmulDesc_t operationDesc; + cublasLtMatrixLayout_t Adesc, Bdesc, Cdesc, Ddesc; + cublasLtMatmulPreference_t preference; + cublasLtMatmulHeuristicResult_t heuristicResult; + int returnedResults = 0; + + cublasComputeType_t computeType = 
CUBLAS_COMPUTE_32F; + cudaDataType_t scaleType = CUDA_R_32F; + cudaDataType_t dataType = CUDA_R_16BF; + + // Create operation descriptor + cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType); + + // Set transpose operations for row-major inputs + // For row-major C = A @ B: + // Use CUBLAS_OP_N for both since we're treating row-major as transposed col-major + cublasOperation_t transA = CUBLAS_OP_T; + cublasOperation_t transB = CUBLAS_OP_N; + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transA, sizeof(transA)); + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transB, sizeof(transB)); + + // Matrix layouts (column-major perspective) + // A: [K, 1] in col-major = [1, K] row-major + // B: [K, N] in col-major = [N, K] row-major, but we have [K, N] row-major + // Need to swap and transpose + + // Actually, let's use the standard row-major approach: + // For row-major C[M,N] = A[M,K] @ B[K,N]: + // Compute as: C^T[N,M] = B^T[N,K] @ A^T[K,M] + // In cuBLASLt terms with ColumnMajor default: + // D[N,M] = B[N,K] @ A[K,M] where matrices are stored as their transposes + + // For M=1: + // D[N,1] = B[N,K] @ A[K,1] + // m=N, n=1, k=K + + int m = N; + int n = 1; + int k = K; + + int lda = K; // Leading dim of A (row-major A[1,K]) + int ldb = N; // Leading dim of B (row-major B[K,N]) + int ldc = N; // Leading dim of C (row-major C[1,N]) + + // Create matrix layouts + // A as [K, 1] column-major (which is A^T of our row-major [1, K]) + cublasLtMatrixLayoutCreate(&Adesc, dataType, k, n, lda); + + // B as [N, K] column-major (which is B^T of our row-major [K, N]) + cublasLtMatrixLayoutCreate(&Bdesc, dataType, m, k, ldb); + + // C/D as [N, 1] column-major (which is C^T of our row-major [1, N]) + cublasLtMatrixLayoutCreate(&Cdesc, dataType, m, n, ldc); + cublasLtMatrixLayoutCreate(&Ddesc, dataType, m, n, ldc); + + // Create preference + cublasLtMatmulPreferenceCreate(&preference); + size_t workspaceSize = 0; + cublasLtMatmulPreferenceSetAttribute(preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); + + // Get heuristic + cublasLtMatmulAlgoGetHeuristic(handle_, operationDesc, Bdesc, Adesc, Cdesc, Ddesc, + preference, 1, &heuristicResult, &returnedResults); + + if (returnedResults == 0) { + // Cleanup + cublasLtMatmulPreferenceDestroy(preference); + cublasLtMatrixLayoutDestroy(Ddesc); + cublasLtMatrixLayoutDestroy(Cdesc); + cublasLtMatrixLayoutDestroy(Bdesc); + cublasLtMatrixLayoutDestroy(Adesc); + cublasLtMatmulDescDestroy(operationDesc); + return cudaErrorNotSupported; + } + + // Execute GEMM + // Note: For row-major, we swap A and B pointers + cublasStatus_t status = cublasLtMatmul(handle_, + operationDesc, + &alpha, + B, Bdesc, // First operand (was A in col-major) + A, Adesc, // Second operand (was B in col-major) + &beta, + C, Cdesc, + C, Ddesc, // Output + &heuristicResult.algo, + nullptr, 0, + stream); + + // Cleanup + cublasLtMatmulPreferenceDestroy(preference); + cublasLtMatrixLayoutDestroy(Ddesc); + cublasLtMatrixLayoutDestroy(Cdesc); + cublasLtMatrixLayoutDestroy(Bdesc); + cublasLtMatrixLayoutDestroy(Adesc); + cublasLtMatmulDescDestroy(operationDesc); + + return (status == CUBLAS_STATUS_SUCCESS) ? 
cudaSuccess : cudaErrorUnknown; + } + +private: + cublasLtHandle_t handle_; +}; + +// ============================================================================ +// Benchmark Utilities +// ============================================================================ + +void initialize_random_bf16(__nv_bfloat16* data, size_t count) { + std::vector host(count); + for (size_t i = 0; i < count; ++i) { + host[i] = (static_cast(rand()) / RAND_MAX - 0.5f) * 0.1f; + } + std::vector<__nv_bfloat16> host_bf16(count); + for (size_t i = 0; i < count; ++i) { + host_bf16[i] = __float2bfloat16(host[i]); + } + cudaMemcpy(data, host_bf16.data(), count * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice); +} + +float compute_max_error_bf16(__nv_bfloat16* A, __nv_bfloat16* B, size_t count) { + std::vector<__nv_bfloat16> host_A(count), host_B(count); + cudaMemcpy(host_A.data(), A, count * sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost); + cudaMemcpy(host_B.data(), B, count * sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost); + + float max_err = 0.0f; + for (size_t i = 0; i < count; ++i) { + float a = __bfloat162float(host_A[i]); + float b = __bfloat162float(host_B[i]); + float err = std::abs(a - b); + max_err = std::max(max_err, err); + } + return max_err; +} + +// ============================================================================ +// Benchmark Runner +// ============================================================================ + +struct BenchmarkResult { + double cutlass_us; + double cublaslt_us; + float speedup; + float max_error; +}; + +BenchmarkResult run_benchmark(int K, int N, CuBLASLtGemv& cublas) { + BenchmarkResult result; + + // Allocate device memory + __nv_bfloat16 *d_A, *d_B, *d_C_cutlass, *d_C_cublas; + cudaMalloc(&d_A, 1 * K * sizeof(__nv_bfloat16)); + cudaMalloc(&d_B, K * N * sizeof(__nv_bfloat16)); + cudaMalloc(&d_C_cutlass, 1 * N * sizeof(__nv_bfloat16)); + cudaMalloc(&d_C_cublas, 1 * N * sizeof(__nv_bfloat16)); + + // Initialize with random data + initialize_random_bf16(d_A, K); + initialize_random_bf16(d_B, K * N); + cudaMemset(d_C_cutlass, 0, N * sizeof(__nv_bfloat16)); + cudaMemset(d_C_cublas, 0, N * sizeof(__nv_bfloat16)); + + // Create CUDA events for timing + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + // ======================================================================== + // Benchmark CUTLASS GEMV + // ======================================================================== + + // Warmup + for (int i = 0; i < WARMUP_ITERATIONS; ++i) { + pygpukit::ops::gemv::launch_gemv_bf16(d_A, d_B, d_C_cutlass, K, N); + } + cudaDeviceSynchronize(); + + // Timed iterations + cudaEventRecord(start); + for (int i = 0; i < BENCHMARK_ITERATIONS; ++i) { + pygpukit::ops::gemv::launch_gemv_bf16(d_A, d_B, d_C_cutlass, K, N); + } + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + float cutlass_ms; + cudaEventElapsedTime(&cutlass_ms, start, stop); + result.cutlass_us = (cutlass_ms * 1000.0) / BENCHMARK_ITERATIONS; + + // ======================================================================== + // Benchmark cuBLASLt GEMV + // ======================================================================== + + // Warmup + for (int i = 0; i < WARMUP_ITERATIONS; ++i) { + cublas.gemv_bf16(d_A, d_B, d_C_cublas, K, N, 1.0f, 0.0f, nullptr); + } + cudaDeviceSynchronize(); + + // Timed iterations + cudaEventRecord(start); + for (int i = 0; i < BENCHMARK_ITERATIONS; ++i) { + cublas.gemv_bf16(d_A, d_B, d_C_cublas, K, N, 1.0f, 0.0f, nullptr); + } + cudaEventRecord(stop); + 
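/* block until the last timed cuBLASLt launch has finished before reading the event timers */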
cudaEventSynchronize(stop); + + float cublaslt_ms; + cudaEventElapsedTime(&cublaslt_ms, start, stop); + result.cublaslt_us = (cublaslt_ms * 1000.0) / BENCHMARK_ITERATIONS; + + // ======================================================================== + // Compute error + // ======================================================================== + + result.max_error = compute_max_error_bf16(d_C_cutlass, d_C_cublas, N); + result.speedup = result.cublaslt_us / result.cutlass_us; + + // Cleanup + cudaEventDestroy(start); + cudaEventDestroy(stop); + cudaFree(d_A); + cudaFree(d_B); + cudaFree(d_C_cutlass); + cudaFree(d_C_cublas); + + return result; +} + +// ============================================================================ +// Main +// ============================================================================ + +int main(int argc, char* argv[]) { + // Print device info + int device; + cudaGetDevice(&device); + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device); + printf("Device: %s (SM %d%d)\n", props.name, props.major, props.minor); + printf("Memory: %.1f GB\n", props.totalGlobalMem / 1e9); + printf("\n"); + + // Initialize cuBLASLt + CuBLASLtGemv cublas; + + // Print header + printf("GEMV Benchmark: CUTLASS vs cuBLASLt (BF16, M=1)\n"); + printf("Warmup: %d iterations, Benchmark: %d iterations\n", WARMUP_ITERATIONS, BENCHMARK_ITERATIONS); + printf("\n"); + printf("%-30s %10s %10s %10s %10s %10s\n", + "Case", "K", "N", "CUTLASS", "cuBLASLt", "Speedup"); + printf("%-30s %10s %10s %10s %10s %10s\n", + "", "", "", "(us)", "(us)", ""); + printf("--------------------------------------------------------------------------------\n"); + + // Run benchmarks + for (const auto& test : BENCHMARK_CASES) { + BenchmarkResult result = run_benchmark(test.K, test.N, cublas); + + printf("%-30s %10d %10d %10.2f %10.2f %9.2fx %s\n", + test.name, + test.K, test.N, + result.cutlass_us, + result.cublaslt_us, + result.speedup, + result.speedup >= 1.0f ? "(CUTLASS wins)" : "(cuBLASLt wins)"); + + if (result.max_error > 0.01f) { + printf(" WARNING: Max error = %.6f\n", result.max_error); + } + } + + printf("\n"); + printf("================================================================================\n"); + printf("Analysis:\n"); + printf("================================================================================\n"); + printf("\n"); + printf("Performance gap causes (when cuBLASLt wins):\n"); + printf("1. cuBLASLt uses hand-tuned PTX/SASS assembly\n"); + printf("2. cuBLASLt may use specialized M=1 kernel paths\n"); + printf("3. cuBLASLt may use different memory access patterns (texture cache)\n"); + printf("4. Our UNROLL_K=8 may not be optimal for all K sizes\n"); + printf("\n"); + printf("Improvement opportunities for CUTLASS GEMV:\n"); + printf("1. Tune BLOCK_SIZE and UNROLL_K per (K, N) range\n"); + printf("2. Add shared memory tiling for A (reduces L2 pressure)\n"); + printf("3. Use vectorized BF16x2 or BF16x4 loads where aligned\n"); + printf("4. Add software pipelining (async copy + compute overlap)\n"); + printf("5. Consider warp specialization for very large K\n"); + printf("\n"); + printf("Future FP8/SM120 considerations:\n"); + printf("1. FP8 E4M3/E5M2 would require custom quantization\n"); + printf("2. SM120 lacks native FP8 GEMV support in CUTLASS 4.x\n"); + printf("3. BF16 fallback is the current solution for SM120\n"); + printf("4. 
When CUTLASS SM120 FP8 is fixed, add FP8 path\n"); + + return 0; +} diff --git a/native/ops/gemv/build_benchmark.bat b/native/ops/gemv/build_benchmark.bat new file mode 100644 index 0000000..d8ff0ae --- /dev/null +++ b/native/ops/gemv/build_benchmark.bat @@ -0,0 +1,52 @@ +@echo off +REM Build and run GEMV benchmark (vs cuBLASLt) +REM Run from Windows Command Prompt + +setlocal EnableDelayedExpansion + +REM Setup Visual Studio environment +call "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 +if errorlevel 1 ( + echo ERROR: Failed to setup Visual Studio environment + exit /b 1 +) + +REM Setup CUDA environment +if exist "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1\bin\nvcc.exe" ( + set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1 + set SM_ARCH=120 +) else if exist "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9\bin\nvcc.exe" ( + set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9 + set SM_ARCH=86 +) else ( + echo ERROR: CUDA not found + exit /b 1 +) + +set PATH=%CUDA_PATH%\bin;%PATH% + +echo. +echo ============================================ +echo GEMV Benchmark Build +echo ============================================ +echo CUDA: %CUDA_PATH% +echo SM: %SM_ARCH% +echo. + +REM Change to script directory +cd /d %~dp0 + +REM Build benchmark (linking cuBLASLt) +echo Building benchmark_gemv.cu... +nvcc -std=c++17 -O3 -arch=sm_%SM_ARCH% benchmark_gemv.cu -lcublasLt -o benchmark_gemv.exe +if errorlevel 1 ( + echo ERROR: Build failed + exit /b 1 +) + +echo. +echo Running benchmark... +echo. +"%~dp0benchmark_gemv.exe" + +endlocal diff --git a/native/ops/gemv/build_test.bat b/native/ops/gemv/build_test.bat new file mode 100644 index 0000000..6a82e0d --- /dev/null +++ b/native/ops/gemv/build_test.bat @@ -0,0 +1,55 @@ +@echo off +REM Build and run GEMV tests +REM Run from Windows Command Prompt + +setlocal EnableDelayedExpansion + +REM Setup Visual Studio environment +call "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 +if errorlevel 1 ( + echo ERROR: Failed to setup Visual Studio environment + exit /b 1 +) + +REM Setup CUDA environment - try different versions +if exist "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1\bin\nvcc.exe" ( + set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1 + set SM_ARCH=120 +) else if exist "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9\bin\nvcc.exe" ( + set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9 + set SM_ARCH=86 +) else if exist "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\nvcc.exe" ( + set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4 + set SM_ARCH=86 +) else ( + echo ERROR: CUDA not found + exit /b 1 +) + +set PATH=%CUDA_PATH%\bin;%PATH% + +echo. +echo ============================================ +echo GEMV Test Build +echo ============================================ +echo CUDA: %CUDA_PATH% +echo SM: %SM_ARCH% +echo. + +REM Change to script directory +cd /d %~dp0 + +REM Build test +echo Building test_gemv.cu... +nvcc -std=c++17 -O3 -arch=sm_%SM_ARCH% test_gemv.cu -o test_gemv.exe +if errorlevel 1 ( + echo ERROR: Build failed + exit /b 1 +) + +echo. +echo Running tests... +echo. 
+"%~dp0test_gemv.exe" + +endlocal diff --git a/native/ops/gemv/gemv_cutlass.cuh b/native/ops/gemv/gemv_cutlass.cuh new file mode 100644 index 0000000..bb4026d --- /dev/null +++ b/native/ops/gemv/gemv_cutlass.cuh @@ -0,0 +1,846 @@ +/** + * CUTLASS-inspired GEMV Kernel for M=1 (LLM Decode Path) + * + * Purpose: Replace cuBLASLt GEMV with CUTLASS-based implementation + * + * Design decisions: + * 1. M=1 is memory-bound, not compute-bound + * 2. TensorCore is inefficient for M=1 (MMA tiles are wasted) + * 3. Scalar FMA with vectorized loads is optimal + * 4. A[1,K] is small, broadcasts via L1/L2 cache + * 5. B[K,N] row-major: adjacent threads read adjacent addresses (coalesced) + * + * Target architectures: + * - SM86 (RTX 30xx): Primary target + * - SM89 (RTX 40xx): Supported + * - SM90 (H100): Supported + * - SM120 (RTX 5090): BF16 fallback + * + * Future extensions: + * - Batched GEMV for continuous batching + * - FP8 for SM90/SM120 when available + * - Fused bias/scale epilogue + */ + +#pragma once + +#include +#include +#include +#include + +namespace pygpukit { +namespace ops { +namespace gemv { + +// ============================================================================ +// Configuration - Per-size tuning +// ============================================================================ + +// Default configuration (medium sizes: K=2048-8192, N=1024-8192) +struct GemvConfig { + static constexpr int BLOCK_SIZE = 256; // 8 warps + static constexpr int TILE_N = 256; + static constexpr int UNROLL_K = 8; + static constexpr int MIN_N = 128; +}; + +// Small K configuration (K < 2048) +// - Smaller unroll to reduce register pressure +// - Good for embedding lookups, small hidden sizes +struct GemvConfigSmallK { + static constexpr int BLOCK_SIZE = 256; + static constexpr int TILE_N = 256; + static constexpr int UNROLL_K = 4; // Less unrolling for small K + static constexpr int MIN_N = 128; +}; + +// Large K configuration (K > 8192) +// - Larger unroll for more ILP +// - Trades registers for throughput +struct GemvConfigLargeK { + static constexpr int BLOCK_SIZE = 256; + static constexpr int TILE_N = 256; + static constexpr int UNROLL_K = 16; // More unrolling for large K + static constexpr int MIN_N = 128; +}; + +// Small N configuration (N < 1024) +// - Smaller tile to avoid wasted threads +// - Better for narrow outputs +struct GemvConfigSmallN { + static constexpr int BLOCK_SIZE = 128; // 4 warps + static constexpr int TILE_N = 128; + static constexpr int UNROLL_K = 8; + static constexpr int MIN_N = 64; +}; + +// Large matrices (K > 8192 AND N > 8192) +// - Maximum unrolling +// - Optimized for LLM MLP layers (8192x28672 etc) +struct GemvConfigLarge { + static constexpr int BLOCK_SIZE = 256; + static constexpr int TILE_N = 256; + static constexpr int UNROLL_K = 16; + static constexpr int MIN_N = 128; +}; + +// ============================================================================ +// Utility Functions +// ============================================================================ + +// Convert BF16 to FP32 with cache hint +__device__ __forceinline__ float ldg_bf16_to_f32(const __nv_bfloat16* ptr) { + return __bfloat162float(__ldg(ptr)); +} + +// Convert FP16 to FP32 with cache hint +__device__ __forceinline__ float ldg_fp16_to_f32(const __half* ptr) { + return __half2float(__ldg(ptr)); +} + +// Vectorized load: Load 2 BF16 values as bfloat162 +__device__ __forceinline__ __nv_bfloat162 ldg_bf16x2(const __nv_bfloat16* ptr) { + return __ldg(reinterpret_cast(ptr)); +} + +// Vectorized 
load: Load 4 BF16 values as 2x bfloat162 +__device__ __forceinline__ void ldg_bf16x4(const __nv_bfloat16* ptr, + __nv_bfloat162& v01, __nv_bfloat162& v23) { + const __nv_bfloat162* ptr2 = reinterpret_cast(ptr); + v01 = __ldg(ptr2); + v23 = __ldg(ptr2 + 1); +} + +// ============================================================================ +// BF16 GEMV Kernel +// ============================================================================ + +/** + * GEMV kernel for BF16: C[1,N] = alpha * A[1,K] @ B[K,N] + beta * C[1,N] + * + * Memory layout (all row-major): + * - A: [1, K] contiguous, small, broadcasts well + * - B: [K, N] row-major, B[k,n] at address k*N+n + * - C: [1, N] contiguous output + * + * Thread mapping: + * - Each thread handles one output element C[global_n] + * - All threads in block iterate over K together + * - Coalesced access: threads 0-255 read B[k, block_start:block_start+256] + * + * Optimization techniques: + * 1. __ldg() for read-only cache (B access) + * 2. A broadcast via L1/L2 (all threads read same A[k]) + * 3. FMA accumulation in FP32 for precision + * 4. K-loop unrolling (UNROLL_K=8) for ILP + * 5. Predicated loads for K remainder handling + * 6. Vectorized BF16x2 loads for A (reduces memory transactions) + */ +template +__global__ void gemv_bf16_kernel( + __nv_bfloat16 const* __restrict__ A, // [1, K] + __nv_bfloat16 const* __restrict__ B, // [K, N] + __nv_bfloat16* __restrict__ C, // [1, N] + int K, + int N, + float alpha, + float beta +) { + // Thread/block indexing + const int tid = threadIdx.x; + const int block_n = blockIdx.x * Config::TILE_N; + const int global_n = block_n + tid; + + // Bounds check for partial blocks at the end + if (global_n >= N) return; + + // Accumulator in FP32 for numerical precision + // cuBLASLt also uses FP32 accumulation for BF16 + float acc = 0.0f; + + // Base pointer for this thread's column of B + // B[k, global_n] = B[k * N + global_n] + const __nv_bfloat16* B_col = B + global_n; + + // Main K loop with UNROLL_K unrolling + // Rationale: Hides memory latency, increases ILP + int k = 0; + constexpr int UNROLL = Config::UNROLL_K; + + // Template-based unrolling: UNROLL_K can be 4, 8, or 16 + for (; k + UNROLL <= K; k += UNROLL) { + // UNROLL_K=4: Load 2 bfloat162 (4 values) + // UNROLL_K=8: Load 4 bfloat162 (8 values) + // UNROLL_K=16: Load 8 bfloat162 (16 values) + + if constexpr (UNROLL == 4) { + __nv_bfloat162 a01 = ldg_bf16x2(A + k + 0); + __nv_bfloat162 a23 = ldg_bf16x2(A + k + 2); + float a0 = __low2float(a01); + float a1 = __high2float(a01); + float a2 = __low2float(a23); + float a3 = __high2float(a23); + float b0 = ldg_bf16_to_f32(B_col + (k + 0) * N); + float b1 = ldg_bf16_to_f32(B_col + (k + 1) * N); + float b2 = ldg_bf16_to_f32(B_col + (k + 2) * N); + float b3 = ldg_bf16_to_f32(B_col + (k + 3) * N); + acc = fmaf(a0, b0, acc); + acc = fmaf(a1, b1, acc); + acc = fmaf(a2, b2, acc); + acc = fmaf(a3, b3, acc); + } else if constexpr (UNROLL == 8) { + __nv_bfloat162 a01 = ldg_bf16x2(A + k + 0); + __nv_bfloat162 a23 = ldg_bf16x2(A + k + 2); + __nv_bfloat162 a45 = ldg_bf16x2(A + k + 4); + __nv_bfloat162 a67 = ldg_bf16x2(A + k + 6); + float a0 = __low2float(a01); + float a1 = __high2float(a01); + float a2 = __low2float(a23); + float a3 = __high2float(a23); + float a4 = __low2float(a45); + float a5 = __high2float(a45); + float a6 = __low2float(a67); + float a7 = __high2float(a67); + float b0 = ldg_bf16_to_f32(B_col + (k + 0) * N); + float b1 = ldg_bf16_to_f32(B_col + (k + 1) * N); + float b2 = ldg_bf16_to_f32(B_col + (k + 
2) * N); + float b3 = ldg_bf16_to_f32(B_col + (k + 3) * N); + float b4 = ldg_bf16_to_f32(B_col + (k + 4) * N); + float b5 = ldg_bf16_to_f32(B_col + (k + 5) * N); + float b6 = ldg_bf16_to_f32(B_col + (k + 6) * N); + float b7 = ldg_bf16_to_f32(B_col + (k + 7) * N); + acc = fmaf(a0, b0, acc); + acc = fmaf(a1, b1, acc); + acc = fmaf(a2, b2, acc); + acc = fmaf(a3, b3, acc); + acc = fmaf(a4, b4, acc); + acc = fmaf(a5, b5, acc); + acc = fmaf(a6, b6, acc); + acc = fmaf(a7, b7, acc); + } else if constexpr (UNROLL == 16) { + __nv_bfloat162 a01 = ldg_bf16x2(A + k + 0); + __nv_bfloat162 a23 = ldg_bf16x2(A + k + 2); + __nv_bfloat162 a45 = ldg_bf16x2(A + k + 4); + __nv_bfloat162 a67 = ldg_bf16x2(A + k + 6); + __nv_bfloat162 a89 = ldg_bf16x2(A + k + 8); + __nv_bfloat162 aAB = ldg_bf16x2(A + k + 10); + __nv_bfloat162 aCD = ldg_bf16x2(A + k + 12); + __nv_bfloat162 aEF = ldg_bf16x2(A + k + 14); + float a0 = __low2float(a01); + float a1 = __high2float(a01); + float a2 = __low2float(a23); + float a3 = __high2float(a23); + float a4 = __low2float(a45); + float a5 = __high2float(a45); + float a6 = __low2float(a67); + float a7 = __high2float(a67); + float a8 = __low2float(a89); + float a9 = __high2float(a89); + float aA = __low2float(aAB); + float aB = __high2float(aAB); + float aC = __low2float(aCD); + float aD = __high2float(aCD); + float aE = __low2float(aEF); + float aF = __high2float(aEF); + float b0 = ldg_bf16_to_f32(B_col + (k + 0) * N); + float b1 = ldg_bf16_to_f32(B_col + (k + 1) * N); + float b2 = ldg_bf16_to_f32(B_col + (k + 2) * N); + float b3 = ldg_bf16_to_f32(B_col + (k + 3) * N); + float b4 = ldg_bf16_to_f32(B_col + (k + 4) * N); + float b5 = ldg_bf16_to_f32(B_col + (k + 5) * N); + float b6 = ldg_bf16_to_f32(B_col + (k + 6) * N); + float b7 = ldg_bf16_to_f32(B_col + (k + 7) * N); + float b8 = ldg_bf16_to_f32(B_col + (k + 8) * N); + float b9 = ldg_bf16_to_f32(B_col + (k + 9) * N); + float bA = ldg_bf16_to_f32(B_col + (k + 10) * N); + float bB = ldg_bf16_to_f32(B_col + (k + 11) * N); + float bC = ldg_bf16_to_f32(B_col + (k + 12) * N); + float bD = ldg_bf16_to_f32(B_col + (k + 13) * N); + float bE = ldg_bf16_to_f32(B_col + (k + 14) * N); + float bF = ldg_bf16_to_f32(B_col + (k + 15) * N); + acc = fmaf(a0, b0, acc); + acc = fmaf(a1, b1, acc); + acc = fmaf(a2, b2, acc); + acc = fmaf(a3, b3, acc); + acc = fmaf(a4, b4, acc); + acc = fmaf(a5, b5, acc); + acc = fmaf(a6, b6, acc); + acc = fmaf(a7, b7, acc); + acc = fmaf(a8, b8, acc); + acc = fmaf(a9, b9, acc); + acc = fmaf(aA, bA, acc); + acc = fmaf(aB, bB, acc); + acc = fmaf(aC, bC, acc); + acc = fmaf(aD, bD, acc); + acc = fmaf(aE, bE, acc); + acc = fmaf(aF, bF, acc); + } + } + + // Handle K remainder (when K is not divisible by UNROLL_K) + for (; k < K; ++k) { + float a = __bfloat162float(A[k]); + float b = ldg_bf16_to_f32(B_col + k * N); + acc = fmaf(a, b, acc); + } + + // Epilogue: Apply alpha/beta scaling + // Matches cuBLASLt behavior: D = alpha * A @ B + beta * C + if (beta != 0.0f) { + float c_old = __bfloat162float(C[global_n]); + acc = fmaf(alpha, acc, beta * c_old); + } else { + acc *= alpha; + } + + // Store result + C[global_n] = __float2bfloat16(acc); +} + +// ============================================================================ +// FP16 GEMV Kernel +// ============================================================================ + +/** + * GEMV kernel for FP16: C[1,N] = alpha * A[1,K] @ B[K,N] + beta * C[1,N] + * Same design as BF16, using FP16 intrinsics + */ +template +__global__ void gemv_fp16_kernel( + __half const* __restrict__ A, + 
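/* B: [K, N] row-major weights, same layout as the BF16 kernel */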
__half const* __restrict__ B, + __half* __restrict__ C, + int K, + int N, + float alpha, + float beta +) { + const int tid = threadIdx.x; + const int block_n = blockIdx.x * Config::TILE_N; + const int global_n = block_n + tid; + + if (global_n >= N) return; + + float acc = 0.0f; + const __half* B_col = B + global_n; + + int k = 0; + constexpr int UNROLL = Config::UNROLL_K; + + for (; k + UNROLL <= K; k += UNROLL) { + float a0 = __half2float(A[k + 0]); + float a1 = __half2float(A[k + 1]); + float a2 = __half2float(A[k + 2]); + float a3 = __half2float(A[k + 3]); + float a4 = __half2float(A[k + 4]); + float a5 = __half2float(A[k + 5]); + float a6 = __half2float(A[k + 6]); + float a7 = __half2float(A[k + 7]); + + float b0 = ldg_fp16_to_f32(B_col + (k + 0) * N); + float b1 = ldg_fp16_to_f32(B_col + (k + 1) * N); + float b2 = ldg_fp16_to_f32(B_col + (k + 2) * N); + float b3 = ldg_fp16_to_f32(B_col + (k + 3) * N); + float b4 = ldg_fp16_to_f32(B_col + (k + 4) * N); + float b5 = ldg_fp16_to_f32(B_col + (k + 5) * N); + float b6 = ldg_fp16_to_f32(B_col + (k + 6) * N); + float b7 = ldg_fp16_to_f32(B_col + (k + 7) * N); + + acc = fmaf(a0, b0, acc); + acc = fmaf(a1, b1, acc); + acc = fmaf(a2, b2, acc); + acc = fmaf(a3, b3, acc); + acc = fmaf(a4, b4, acc); + acc = fmaf(a5, b5, acc); + acc = fmaf(a6, b6, acc); + acc = fmaf(a7, b7, acc); + } + + for (; k < K; ++k) { + float a = __half2float(A[k]); + float b = ldg_fp16_to_f32(B_col + k * N); + acc = fmaf(a, b, acc); + } + + if (beta != 0.0f) { + float c_old = __half2float(C[global_n]); + acc = fmaf(alpha, acc, beta * c_old); + } else { + acc *= alpha; + } + + C[global_n] = __float2half(acc); +} + +// ============================================================================ +// TF32 GEMV Kernel (FP32 input, TF32-style accumulation) +// ============================================================================ + +/** + * GEMV kernel for FP32: C[1,N] = alpha * A[1,K] @ B[K,N] + beta * C[1,N] + * Uses FP32 accumulation (no TensorCore at M=1) + */ +template +__global__ void gemv_fp32_kernel( + float const* __restrict__ A, + float const* __restrict__ B, + float* __restrict__ C, + int K, + int N, + float alpha, + float beta +) { + const int tid = threadIdx.x; + const int block_n = blockIdx.x * Config::TILE_N; + const int global_n = block_n + tid; + + if (global_n >= N) return; + + float acc = 0.0f; + const float* B_col = B + global_n; + + int k = 0; + constexpr int UNROLL = Config::UNROLL_K; + + for (; k + UNROLL <= K; k += UNROLL) { + float a0 = A[k + 0]; + float a1 = A[k + 1]; + float a2 = A[k + 2]; + float a3 = A[k + 3]; + float a4 = A[k + 4]; + float a5 = A[k + 5]; + float a6 = A[k + 6]; + float a7 = A[k + 7]; + + float b0 = __ldg(B_col + (k + 0) * N); + float b1 = __ldg(B_col + (k + 1) * N); + float b2 = __ldg(B_col + (k + 2) * N); + float b3 = __ldg(B_col + (k + 3) * N); + float b4 = __ldg(B_col + (k + 4) * N); + float b5 = __ldg(B_col + (k + 5) * N); + float b6 = __ldg(B_col + (k + 6) * N); + float b7 = __ldg(B_col + (k + 7) * N); + + acc = fmaf(a0, b0, acc); + acc = fmaf(a1, b1, acc); + acc = fmaf(a2, b2, acc); + acc = fmaf(a3, b3, acc); + acc = fmaf(a4, b4, acc); + acc = fmaf(a5, b5, acc); + acc = fmaf(a6, b6, acc); + acc = fmaf(a7, b7, acc); + } + + for (; k < K; ++k) { + float a = A[k]; + float b = __ldg(B_col + k * N); + acc = fmaf(a, b, acc); + } + + if (beta != 0.0f) { + acc = fmaf(alpha, acc, beta * C[global_n]); + } else { + acc *= alpha; + } + + C[global_n] = acc; +} + +// 
============================================================================ +// Batched GEMV Kernels (for continuous batching) +// ============================================================================ + +/** + * Batched GEMV: C[batch,1,N] = A[batch,1,K] @ B[K,N] + * B is shared across batches (weight matrix) + * A is different per batch (activations) + * + * Grid: (ceil(N/TILE_N), batch_count) + * Each block handles one (batch, tile_n) pair + */ +template +__global__ void gemv_bf16_batched_kernel( + __nv_bfloat16 const* __restrict__ A, // [batch, K] + __nv_bfloat16 const* __restrict__ B, // [K, N] shared + __nv_bfloat16* __restrict__ C, // [batch, N] + int K, + int N, + int batch_count, + float alpha, + float beta +) { + const int tid = threadIdx.x; + const int block_n = blockIdx.x * Config::TILE_N; + const int batch_idx = blockIdx.y; + const int global_n = block_n + tid; + + if (global_n >= N || batch_idx >= batch_count) return; + + // Batch-specific A and C pointers + const __nv_bfloat16* A_batch = A + batch_idx * K; + __nv_bfloat16* C_batch = C + batch_idx * N; + + float acc = 0.0f; + const __nv_bfloat16* B_col = B + global_n; + + int k = 0; + constexpr int UNROLL = Config::UNROLL_K; + + // Template-based unrolling: UNROLL_K can be 4, 8, or 16 + for (; k + UNROLL <= K; k += UNROLL) { + if constexpr (UNROLL == 4) { + __nv_bfloat162 a01 = ldg_bf16x2(A_batch + k + 0); + __nv_bfloat162 a23 = ldg_bf16x2(A_batch + k + 2); + float a0 = __low2float(a01); + float a1 = __high2float(a01); + float a2 = __low2float(a23); + float a3 = __high2float(a23); + float b0 = ldg_bf16_to_f32(B_col + (k + 0) * N); + float b1 = ldg_bf16_to_f32(B_col + (k + 1) * N); + float b2 = ldg_bf16_to_f32(B_col + (k + 2) * N); + float b3 = ldg_bf16_to_f32(B_col + (k + 3) * N); + acc = fmaf(a0, b0, acc); + acc = fmaf(a1, b1, acc); + acc = fmaf(a2, b2, acc); + acc = fmaf(a3, b3, acc); + } else if constexpr (UNROLL == 8) { + __nv_bfloat162 a01 = ldg_bf16x2(A_batch + k + 0); + __nv_bfloat162 a23 = ldg_bf16x2(A_batch + k + 2); + __nv_bfloat162 a45 = ldg_bf16x2(A_batch + k + 4); + __nv_bfloat162 a67 = ldg_bf16x2(A_batch + k + 6); + float a0 = __low2float(a01); + float a1 = __high2float(a01); + float a2 = __low2float(a23); + float a3 = __high2float(a23); + float a4 = __low2float(a45); + float a5 = __high2float(a45); + float a6 = __low2float(a67); + float a7 = __high2float(a67); + float b0 = ldg_bf16_to_f32(B_col + (k + 0) * N); + float b1 = ldg_bf16_to_f32(B_col + (k + 1) * N); + float b2 = ldg_bf16_to_f32(B_col + (k + 2) * N); + float b3 = ldg_bf16_to_f32(B_col + (k + 3) * N); + float b4 = ldg_bf16_to_f32(B_col + (k + 4) * N); + float b5 = ldg_bf16_to_f32(B_col + (k + 5) * N); + float b6 = ldg_bf16_to_f32(B_col + (k + 6) * N); + float b7 = ldg_bf16_to_f32(B_col + (k + 7) * N); + acc = fmaf(a0, b0, acc); + acc = fmaf(a1, b1, acc); + acc = fmaf(a2, b2, acc); + acc = fmaf(a3, b3, acc); + acc = fmaf(a4, b4, acc); + acc = fmaf(a5, b5, acc); + acc = fmaf(a6, b6, acc); + acc = fmaf(a7, b7, acc); + } else if constexpr (UNROLL == 16) { + __nv_bfloat162 a01 = ldg_bf16x2(A_batch + k + 0); + __nv_bfloat162 a23 = ldg_bf16x2(A_batch + k + 2); + __nv_bfloat162 a45 = ldg_bf16x2(A_batch + k + 4); + __nv_bfloat162 a67 = ldg_bf16x2(A_batch + k + 6); + __nv_bfloat162 a89 = ldg_bf16x2(A_batch + k + 8); + __nv_bfloat162 aAB = ldg_bf16x2(A_batch + k + 10); + __nv_bfloat162 aCD = ldg_bf16x2(A_batch + k + 12); + __nv_bfloat162 aEF = ldg_bf16x2(A_batch + k + 14); + float a0 = __low2float(a01); + float a1 = __high2float(a01); + float a2 = 
__low2float(a23); + float a3 = __high2float(a23); + float a4 = __low2float(a45); + float a5 = __high2float(a45); + float a6 = __low2float(a67); + float a7 = __high2float(a67); + float a8 = __low2float(a89); + float a9 = __high2float(a89); + float aA = __low2float(aAB); + float aB = __high2float(aAB); + float aC = __low2float(aCD); + float aD = __high2float(aCD); + float aE = __low2float(aEF); + float aF = __high2float(aEF); + float b0 = ldg_bf16_to_f32(B_col + (k + 0) * N); + float b1 = ldg_bf16_to_f32(B_col + (k + 1) * N); + float b2 = ldg_bf16_to_f32(B_col + (k + 2) * N); + float b3 = ldg_bf16_to_f32(B_col + (k + 3) * N); + float b4 = ldg_bf16_to_f32(B_col + (k + 4) * N); + float b5 = ldg_bf16_to_f32(B_col + (k + 5) * N); + float b6 = ldg_bf16_to_f32(B_col + (k + 6) * N); + float b7 = ldg_bf16_to_f32(B_col + (k + 7) * N); + float b8 = ldg_bf16_to_f32(B_col + (k + 8) * N); + float b9 = ldg_bf16_to_f32(B_col + (k + 9) * N); + float bA = ldg_bf16_to_f32(B_col + (k + 10) * N); + float bB = ldg_bf16_to_f32(B_col + (k + 11) * N); + float bC = ldg_bf16_to_f32(B_col + (k + 12) * N); + float bD = ldg_bf16_to_f32(B_col + (k + 13) * N); + float bE = ldg_bf16_to_f32(B_col + (k + 14) * N); + float bF = ldg_bf16_to_f32(B_col + (k + 15) * N); + acc = fmaf(a0, b0, acc); + acc = fmaf(a1, b1, acc); + acc = fmaf(a2, b2, acc); + acc = fmaf(a3, b3, acc); + acc = fmaf(a4, b4, acc); + acc = fmaf(a5, b5, acc); + acc = fmaf(a6, b6, acc); + acc = fmaf(a7, b7, acc); + acc = fmaf(a8, b8, acc); + acc = fmaf(a9, b9, acc); + acc = fmaf(aA, bA, acc); + acc = fmaf(aB, bB, acc); + acc = fmaf(aC, bC, acc); + acc = fmaf(aD, bD, acc); + acc = fmaf(aE, bE, acc); + acc = fmaf(aF, bF, acc); + } + } + + for (; k < K; ++k) { + float a = __bfloat162float(A_batch[k]); + float b = ldg_bf16_to_f32(B_col + k * N); + acc = fmaf(a, b, acc); + } + + if (beta != 0.0f) { + float c_old = __bfloat162float(C_batch[global_n]); + acc = fmaf(alpha, acc, beta * c_old); + } else { + acc *= alpha; + } + + C_batch[global_n] = __float2bfloat16(acc); +} + +// ============================================================================ +// Launch Functions +// ============================================================================ + +/** + * Launch BF16 GEMV with per-size configuration selection + * + * Configuration selection logic: + * - Small N (< 1024): Use smaller block/tile (GemvConfigSmallN) + * - Small K (< 2048): Use smaller unroll (GemvConfigSmallK) + * - Large K (> 8192) AND Large N (> 8192): Maximum unroll (GemvConfigLarge) + * - Large K (> 8192): Larger unroll (GemvConfigLargeK) + * - Default: Balanced configuration (GemvConfig) + */ +inline cudaError_t launch_gemv_bf16( + const __nv_bfloat16* A, + const __nv_bfloat16* B, + __nv_bfloat16* C, + int K, + int N, + float alpha = 1.0f, + float beta = 0.0f, + cudaStream_t stream = nullptr +) { + // Per-size configuration dispatch + if (N < 1024) { + // Small N: use smaller block to avoid wasted threads + using Config = GemvConfigSmallN; + dim3 block(Config::BLOCK_SIZE); + dim3 grid((N + Config::TILE_N - 1) / Config::TILE_N); + gemv_bf16_kernel<<>>( + A, B, C, K, N, alpha, beta); + } else if (K > 8192 && N > 8192) { + // Large matrices: maximum unrolling + using Config = GemvConfigLarge; + dim3 block(Config::BLOCK_SIZE); + dim3 grid((N + Config::TILE_N - 1) / Config::TILE_N); + gemv_bf16_kernel<<>>( + A, B, C, K, N, alpha, beta); + } else if (K > 8192) { + // Large K: more unrolling for ILP + using Config = GemvConfigLargeK; + dim3 block(Config::BLOCK_SIZE); + dim3 grid((N + Config::TILE_N 
- 1) / Config::TILE_N); + gemv_bf16_kernel<<>>( + A, B, C, K, N, alpha, beta); + } else if (K < 2048) { + // Small K: less unrolling + using Config = GemvConfigSmallK; + dim3 block(Config::BLOCK_SIZE); + dim3 grid((N + Config::TILE_N - 1) / Config::TILE_N); + gemv_bf16_kernel<<>>( + A, B, C, K, N, alpha, beta); + } else { + // Default: balanced configuration + using Config = GemvConfig; + dim3 block(Config::BLOCK_SIZE); + dim3 grid((N + Config::TILE_N - 1) / Config::TILE_N); + gemv_bf16_kernel<<>>( + A, B, C, K, N, alpha, beta); + } + + return cudaGetLastError(); +} + +/** + * Launch FP16 GEMV + */ +inline cudaError_t launch_gemv_fp16( + const __half* A, + const __half* B, + __half* C, + int K, + int N, + float alpha = 1.0f, + float beta = 0.0f, + cudaStream_t stream = nullptr +) { + using Config = GemvConfig; + + dim3 block(Config::BLOCK_SIZE); + dim3 grid((N + Config::TILE_N - 1) / Config::TILE_N); + + gemv_fp16_kernel<<>>( + A, B, C, K, N, alpha, beta + ); + + return cudaGetLastError(); +} + +/** + * Launch FP32 GEMV + */ +inline cudaError_t launch_gemv_fp32( + const float* A, + const float* B, + float* C, + int K, + int N, + float alpha = 1.0f, + float beta = 0.0f, + cudaStream_t stream = nullptr +) { + using Config = GemvConfig; + + dim3 block(Config::BLOCK_SIZE); + dim3 grid((N + Config::TILE_N - 1) / Config::TILE_N); + + gemv_fp32_kernel<<>>( + A, B, C, K, N, alpha, beta + ); + + return cudaGetLastError(); +} + +/** + * Launch batched BF16 GEMV with per-size configuration selection + */ +inline cudaError_t launch_gemv_bf16_batched( + const __nv_bfloat16* A, // [batch, K] + const __nv_bfloat16* B, // [K, N] + __nv_bfloat16* C, // [batch, N] + int K, + int N, + int batch_count, + float alpha = 1.0f, + float beta = 0.0f, + cudaStream_t stream = nullptr +) { + // Per-size configuration dispatch (same logic as non-batched) + if (N < 1024) { + using Config = GemvConfigSmallN; + dim3 block(Config::BLOCK_SIZE); + dim3 grid((N + Config::TILE_N - 1) / Config::TILE_N, batch_count); + gemv_bf16_batched_kernel<<>>( + A, B, C, K, N, batch_count, alpha, beta); + } else if (K > 8192 && N > 8192) { + using Config = GemvConfigLarge; + dim3 block(Config::BLOCK_SIZE); + dim3 grid((N + Config::TILE_N - 1) / Config::TILE_N, batch_count); + gemv_bf16_batched_kernel<<>>( + A, B, C, K, N, batch_count, alpha, beta); + } else if (K > 8192) { + using Config = GemvConfigLargeK; + dim3 block(Config::BLOCK_SIZE); + dim3 grid((N + Config::TILE_N - 1) / Config::TILE_N, batch_count); + gemv_bf16_batched_kernel<<>>( + A, B, C, K, N, batch_count, alpha, beta); + } else if (K < 2048) { + using Config = GemvConfigSmallK; + dim3 block(Config::BLOCK_SIZE); + dim3 grid((N + Config::TILE_N - 1) / Config::TILE_N, batch_count); + gemv_bf16_batched_kernel<<>>( + A, B, C, K, N, batch_count, alpha, beta); + } else { + using Config = GemvConfig; + dim3 block(Config::BLOCK_SIZE); + dim3 grid((N + Config::TILE_N - 1) / Config::TILE_N, batch_count); + gemv_bf16_batched_kernel<<>>( + A, B, C, K, N, batch_count, alpha, beta); + } + + return cudaGetLastError(); +} + +// ============================================================================ +// Dispatch Function (M=1 detection) +// ============================================================================ + +/** + * GEMM/GEMV dispatcher + * + * Selects GEMV kernel when M=1, otherwise falls through to GEMM + * Returns true if GEMV was dispatched, false if GEMM should be used + */ +inline bool dispatch_gemv_bf16( + const __nv_bfloat16* A, + const __nv_bfloat16* B, + __nv_bfloat16* 
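/* C: [1, N] output row vector, overwritten unless beta != 0 */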
C, + int M, + int N, + int K, + float alpha = 1.0f, + float beta = 0.0f, + cudaStream_t stream = nullptr +) { + // GEMV dispatch conditions: + // 1. M == 1 (single row) + // 2. N >= MIN_N (avoid overhead for tiny outputs) + if (M == 1 && N >= GemvConfig::MIN_N) { + launch_gemv_bf16(A, B, C, K, N, alpha, beta, stream); + return true; + } + return false; +} + +inline bool dispatch_gemv_fp16( + const __half* A, + const __half* B, + __half* C, + int M, + int N, + int K, + float alpha = 1.0f, + float beta = 0.0f, + cudaStream_t stream = nullptr +) { + if (M == 1 && N >= GemvConfig::MIN_N) { + launch_gemv_fp16(A, B, C, K, N, alpha, beta, stream); + return true; + } + return false; +} + +inline bool dispatch_gemv_fp32( + const float* A, + const float* B, + float* C, + int M, + int N, + int K, + float alpha = 1.0f, + float beta = 0.0f, + cudaStream_t stream = nullptr +) { + if (M == 1 && N >= GemvConfig::MIN_N) { + launch_gemv_fp32(A, B, C, K, N, alpha, beta, stream); + return true; + } + return false; +} + +} // namespace gemv +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/gemv/gemv_nvf4.cu b/native/ops/gemv/gemv_nvf4.cu new file mode 100644 index 0000000..4ecb603 --- /dev/null +++ b/native/ops/gemv/gemv_nvf4.cu @@ -0,0 +1,218 @@ +/** + * NVF4 GEMV Implementation for SM120 with BF16 I/O + * + * This file provides: + * 1. NVF4 GEMV kernel dispatch + * 2. BF16 -> NVF4 weight quantization + * 3. Automatic dispatch based on GPU architecture + */ + +#include +#include +#include + +// Include both BF16 and NVF4 GEMV kernels +#include "gemv_cutlass.cuh" +#include "gemv_nvf4_sm120.cuh" + +namespace pygpukit { +namespace ops { +namespace gemv_dispatch { + +// ============================================================================ +// GPU Architecture Detection +// ============================================================================ + +static int cached_sm_version = -1; + +inline int get_sm_version() { + if (cached_sm_version < 0) { + int device_id = 0; + cudaGetDevice(&device_id); + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device_id); + cached_sm_version = props.major * 10 + props.minor; + } + return cached_sm_version; +} + +inline bool is_sm120() { + int sm = get_sm_version(); + return (sm == 120 || sm == 121); +} + +// ============================================================================ +// NVF4 Weight Storage +// ============================================================================ + +/** + * Container for NVF4-quantized weights + */ +struct NVF4Weights { + uint8_t* data; // [K/2, N] packed NVF4 + uint8_t* scale; // [K/32, N] scale factors + int K; + int N; + bool owns_memory; + + NVF4Weights() : data(nullptr), scale(nullptr), K(0), N(0), owns_memory(false) {} + + ~NVF4Weights() { + if (owns_memory) { + if (data) cudaFree(data); + if (scale) cudaFree(scale); + } + } + + // Calculate memory sizes + size_t data_size() const { return (K / 2) * N; } + size_t scale_size() const { return ((K + 31) / 32) * N; } + size_t total_size() const { return data_size() + scale_size(); } + + // Memory savings vs BF16 + float compression_ratio() const { + size_t bf16_size = K * N * 2; // 2 bytes per BF16 + return (float)bf16_size / total_size(); + } +}; + +// ============================================================================ +// Exported Functions +// ============================================================================ + +} // namespace gemv_dispatch +} // namespace ops +} // namespace pygpukit + +// 
============================================================================ +// C API for Python Bindings +// ============================================================================ + +extern "C" { + +/** + * Check if NVF4 GEMV is available + */ +bool pygpukit_gemv_nvf4_available() { + return pygpukit::ops::gemv_nvf4::is_available(); +} + +/** + * Quantize BF16 weights to NVF4 format + * + * @param input [K, N] BF16 row-major + * @param out_data [K/2, N] packed NVF4 (pre-allocated) + * @param out_scale [K/32, N] scale factors (pre-allocated) + * @param K Inner dimension + * @param N Output dimension + */ +cudaError_t pygpukit_quantize_bf16_to_nvf4( + const void* input, + void* out_data, + void* out_scale, + int K, + int N, + cudaStream_t stream +) { + return pygpukit::ops::gemv_nvf4::quantize_bf16_to_nvf4( + static_cast(input), + static_cast(out_data), + static_cast(out_scale), + K, N, stream + ); +} + +/** + * NVF4 GEMV: C[1,N] = A[1,K] @ B[K,N] (NVF4 quantized) + * + * @param A [K] BF16 input vector + * @param B_data [K/2, N] packed NVF4 weights + * @param B_scale [K/32, N] scale factors + * @param C [N] BF16 output vector + * @param K Inner dimension + * @param N Output dimension + * @param alpha Scaling factor + */ +cudaError_t pygpukit_gemv_nvf4_bf16( + const void* A, + const void* B_data, + const void* B_scale, + void* C, + int K, + int N, + float alpha, + cudaStream_t stream +) { + return pygpukit::ops::gemv_nvf4::launch_gemv_nvf4_bf16( + static_cast(A), + static_cast(B_data), + static_cast(B_scale), + static_cast<__nv_bfloat16*>(C), + K, N, alpha, stream + ); +} + +/** + * BF16 GEMV (standard, no quantization) + */ +cudaError_t pygpukit_gemv_bf16( + const void* A, + const void* B, + void* C, + int K, + int N, + float alpha, + float beta, + cudaStream_t stream +) { + return pygpukit::ops::gemv::launch_gemv_bf16( + static_cast(A), + static_cast(B), + static_cast<__nv_bfloat16*>(C), + K, N, alpha, beta, stream + ); +} + +/** + * Auto-dispatch GEMV: Uses NVF4 on SM120 if weights are pre-quantized + * Falls back to BF16 GEMV otherwise + */ +cudaError_t pygpukit_gemv_bf16_auto( + const void* A, + const void* B, + void* C, + int M, + int N, + int K, + float alpha, + float beta, + cudaStream_t stream +) { + // Only dispatch GEMV for M=1 + if (M != 1) { + return cudaErrorInvalidValue; // Use GEMM instead + } + + // Use standard BF16 GEMV (NVF4 requires pre-quantized weights) + return pygpukit::ops::gemv::launch_gemv_bf16( + static_cast(A), + static_cast(B), + static_cast<__nv_bfloat16*>(C), + K, N, alpha, beta, stream + ); +} + +/** + * Get memory sizes for NVF4 quantization + */ +void pygpukit_nvf4_get_sizes( + int K, + int N, + size_t* data_size, + size_t* scale_size +) { + *data_size = (K / 2) * N; + *scale_size = ((K + 31) / 32) * N; +} + +} // extern "C" diff --git a/native/ops/gemv/gemv_nvf4_sm120.cuh b/native/ops/gemv/gemv_nvf4_sm120.cuh new file mode 100644 index 0000000..3debbcf --- /dev/null +++ b/native/ops/gemv/gemv_nvf4_sm120.cuh @@ -0,0 +1,630 @@ +/** + * NVF4 GEMV Kernel for SM120 (Blackwell GeForce) with BF16 I/O + * + * Purpose: Memory-efficient GEMV for LLM inference decode path + * + * Data flow: + * A[1,K] (BF16) x B[K,N] (NVF4 + scale) -> C[1,N] (BF16) + * + * NVF4 (float_e2m1_t) format: + * - 4-bit per element (2 elements per byte) + * - Values: 0, +/-0.5, +/-1, +/-1.5, +/-2, +/-3, +/-4, +/-6 + * - Block scaling: 32 elements share one scale factor (float_ue4m3_t) + * + * Memory layout: + * - B_data: [K, N/2] packed NVF4 (column-major for coalesced access) + * - 
B_scale: [K/32, N] scale factors (one per 32-element block along K) + * + * Advantages over BF16 GEMV: + * - 4x less memory bandwidth for weights + * - Better cache utilization + * - Ideal for memory-bound M=1 decode + */ + +#pragma once + +#include +#include +#include + +namespace pygpukit { +namespace ops { +namespace gemv_nvf4 { + +// ============================================================================ +// NVF4 Dequantization +// ============================================================================ + +// NVF4 E2M1 lookup table (4-bit -> float) +// Index 0-7: positive values, 8-15: negative values +__device__ __constant__ float NVF4_LUT[16] = { + 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, // 0-7: positive + 0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f // 8-15: negative (sign bit) +}; + +// Dequantize NVF4 value using lookup table +__device__ __forceinline__ float dequant_nvf4(uint8_t nvf4_val) { + return NVF4_LUT[nvf4_val & 0x0F]; +} + +// Dequantize packed byte (2 NVF4 values) and apply scale +__device__ __forceinline__ void dequant_nvf4x2( + uint8_t packed, + float scale, + float& out0, + float& out1 +) { + out0 = NVF4_LUT[packed & 0x0F] * scale; + out1 = NVF4_LUT[(packed >> 4) & 0x0F] * scale; +} + +// UE4M3 scale factor lookup table (256 entries for direct byte indexing) +// UE4M3: 4-bit unsigned exponent (bits 3-6), 3-bit mantissa (bits 0-2) +// Value = (1 + mantissa/8) * 2^(exponent - 7) +// Note: bit 7 is unused, so entries 128-255 mirror 0-127 +__device__ __constant__ float UE4M3_SCALE_LUT[256] = { + // exp=0: 2^(-7) = 0.0078125 + 0.0078125f, 0.0087890625f, 0.009765625f, 0.0107421875f, 0.01171875f, 0.0126953125f, 0.013671875f, 0.0146484375f, + // exp=1: 2^(-6) = 0.015625 + 0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f, + // exp=2: 2^(-5) = 0.03125 + 0.03125f, 0.03515625f, 0.0390625f, 0.04296875f, 0.046875f, 0.05078125f, 0.0546875f, 0.05859375f, + // exp=3: 2^(-4) = 0.0625 + 0.0625f, 0.0703125f, 0.078125f, 0.0859375f, 0.09375f, 0.1015625f, 0.109375f, 0.1171875f, + // exp=4: 2^(-3) = 0.125 + 0.125f, 0.140625f, 0.15625f, 0.171875f, 0.1875f, 0.203125f, 0.21875f, 0.234375f, + // exp=5: 2^(-2) = 0.25 + 0.25f, 0.28125f, 0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f, 0.46875f, + // exp=6: 2^(-1) = 0.5 + 0.5f, 0.5625f, 0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f, 0.9375f, + // exp=7: 2^0 = 1.0 + 1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f, + // exp=8: 2^1 = 2.0 + 2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f, + // exp=9: 2^2 = 4.0 + 4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f, + // exp=10: 2^3 = 8.0 + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + // exp=11: 2^4 = 16.0 + 16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f, + // exp=12: 2^5 = 32.0 + 32.0f, 36.0f, 40.0f, 44.0f, 48.0f, 52.0f, 56.0f, 60.0f, + // exp=13: 2^6 = 64.0 + 64.0f, 72.0f, 80.0f, 88.0f, 96.0f, 104.0f, 112.0f, 120.0f, + // exp=14: 2^7 = 128.0 + 128.0f, 144.0f, 160.0f, 176.0f, 192.0f, 208.0f, 224.0f, 240.0f, + // exp=15: 2^8 = 256.0 + 256.0f, 288.0f, 320.0f, 352.0f, 384.0f, 416.0f, 448.0f, 480.0f, + // Mirror for bit 7 set (128-255) + 0.0078125f, 0.0087890625f, 0.009765625f, 0.0107421875f, 0.01171875f, 0.0126953125f, 0.013671875f, 0.0146484375f, + 0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f, + 0.03125f, 0.03515625f, 0.0390625f, 0.04296875f, 0.046875f, 0.05078125f, 0.0546875f, 0.05859375f, + 0.0625f, 0.0703125f, 0.078125f, 0.0859375f, 0.09375f, 
0.1015625f, 0.109375f, 0.1171875f, + 0.125f, 0.140625f, 0.15625f, 0.171875f, 0.1875f, 0.203125f, 0.21875f, 0.234375f, + 0.25f, 0.28125f, 0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f, 0.46875f, + 0.5f, 0.5625f, 0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f, 0.9375f, + 1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f, + 2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f, + 4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f, + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f, + 32.0f, 36.0f, 40.0f, 44.0f, 48.0f, 52.0f, 56.0f, 60.0f, + 64.0f, 72.0f, 80.0f, 88.0f, 96.0f, 104.0f, 112.0f, 120.0f, + 128.0f, 144.0f, 160.0f, 176.0f, 192.0f, 208.0f, 224.0f, 240.0f, + 256.0f, 288.0f, 320.0f, 352.0f, 384.0f, 416.0f, 448.0f, 480.0f, +}; + +// Fast UE4M3 scale decode using LUT (single memory access) +__device__ __forceinline__ float decode_ue4m3_scale(uint8_t ue4m3) { + return UE4M3_SCALE_LUT[ue4m3]; +} + +// ============================================================================ +// Configuration +// ============================================================================ + +struct GemvNvf4Config { + static constexpr int BLOCK_SIZE = 256; // Threads per block + static constexpr int TILE_N = 256; // Output elements per block + static constexpr int UNROLL_K = 8; // K-loop unrolling (must be multiple of 2) + static constexpr int SCALE_BLOCK = 32; // Elements per scale factor +}; + +// ============================================================================ +// NVF4 GEMV Kernel +// ============================================================================ + +/** + * GEMV kernel: C[1,N] = A[1,K] @ B[K,N] where B is NVF4 quantized + * + * Memory layout: + * - A: [K] BF16 contiguous (input vector) + * - B_data: [K/2, N] packed NVF4 (2 elements per byte, row-major) + * B_data[k/2, n] contains B[k, n] (low nibble) and B[k+1, n] (high nibble) + * - B_scale: [K/32, N] UE4M3 scale factors + * - C: [N] BF16 output + */ +template +__global__ void gemv_nvf4_bf16_kernel( + __nv_bfloat16 const* __restrict__ A, // [K] BF16 + uint8_t const* __restrict__ B_data, // [K/2, N] packed NVF4 + uint8_t const* __restrict__ B_scale, // [K/32, N] UE4M3 scales + __nv_bfloat16* __restrict__ C, // [N] BF16 output + int K, + int N, + float alpha +) { + const int tid = threadIdx.x; + const int block_n = blockIdx.x * Config::TILE_N; + const int global_n = block_n + tid; + + if (global_n >= N) return; + + float acc = 0.0f; + + // Base pointers for this thread's column + const uint8_t* B_col = B_data + global_n; // B_data[0, global_n] + const uint8_t* S_col = B_scale + global_n; // B_scale[0, global_n] + + const int K_packed = K / 2; // Packed dimension + const int num_scale_blocks = (K + Config::SCALE_BLOCK - 1) / Config::SCALE_BLOCK; + + // Process in scale blocks (32 elements = 16 packed bytes per block) + for (int sb = 0; sb < num_scale_blocks; ++sb) { + // Load scale factor for this block + float scale = decode_ue4m3_scale(__ldg(S_col + sb * N)); + + int k_start = sb * Config::SCALE_BLOCK; + int k_end = min(k_start + Config::SCALE_BLOCK, K); + + // Process pairs (2 NVF4 values per byte) + for (int k = k_start; k < k_end; k += 2) { + int k_packed = k / 2; + + // Load packed NVF4 byte + uint8_t packed = __ldg(B_col + k_packed * N); + + // Dequantize + float b0, b1; + dequant_nvf4x2(packed, scale, b0, b1); + + // Load A values + float a0 = __bfloat162float(A[k]); + float a1 = (k + 1 < K) ? 
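/* K may be odd here: the missing high-nibble element contributes zero */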
__bfloat162float(A[k + 1]) : 0.0f; + + // Accumulate + acc = fmaf(a0, b0, acc); + acc = fmaf(a1, b1, acc); + } + } + + // Apply alpha and store + C[global_n] = __float2bfloat16(alpha * acc); +} + +/** + * Optimized kernel with register-cached scaled LUT + * + * Key optimization: + * - Pre-compute scaled LUT values once per scale block (16 regs) + * - Eliminates per-value multiply by scale + * - Unrolled inner loop for ILP + */ +template +__global__ void gemv_nvf4_bf16_kernel_unrolled( + __nv_bfloat16 const* __restrict__ A, + uint8_t const* __restrict__ B_data, + uint8_t const* __restrict__ B_scale, + __nv_bfloat16* __restrict__ C, + int K, + int N, + float alpha +) { + const int tid = threadIdx.x; + const int block_n = blockIdx.x * Config::TILE_N; + const int global_n = block_n + tid; + + if (global_n >= N) return; + + float acc = 0.0f; + + const uint8_t* B_col = B_data + global_n; + const uint8_t* S_col = B_scale + global_n; + + const int num_scale_blocks = K / Config::SCALE_BLOCK; + const int K_remainder = K % Config::SCALE_BLOCK; + + // Main loop: process complete scale blocks + for (int sb = 0; sb < num_scale_blocks; ++sb) { + int k_base = sb * Config::SCALE_BLOCK; + + // Load and decode scale factor + float scale = decode_ue4m3_scale(__ldg(S_col + sb * N)); + + // Pre-compute scaled LUT in registers (16 values) + // This eliminates 32 multiplies per scale block (saves 16 net) + float lut0 = 0.0f; // NVF4_LUT[0] * scale + float lut1 = 0.5f * scale; // NVF4_LUT[1] * scale + float lut2 = 1.0f * scale; // NVF4_LUT[2] * scale + float lut3 = 1.5f * scale; // NVF4_LUT[3] * scale + float lut4 = 2.0f * scale; // NVF4_LUT[4] * scale + float lut5 = 3.0f * scale; // NVF4_LUT[5] * scale + float lut6 = 4.0f * scale; // NVF4_LUT[6] * scale + float lut7 = 6.0f * scale; // NVF4_LUT[7] * scale + float lut8 = 0.0f; // NVF4_LUT[8] * scale (neg zero) + float lut9 = -0.5f * scale; // NVF4_LUT[9] * scale + float lut10 = -1.0f * scale; // NVF4_LUT[10] * scale + float lut11 = -1.5f * scale; // NVF4_LUT[11] * scale + float lut12 = -2.0f * scale; // NVF4_LUT[12] * scale + float lut13 = -3.0f * scale; // NVF4_LUT[13] * scale + float lut14 = -4.0f * scale; // NVF4_LUT[14] * scale + float lut15 = -6.0f * scale; // NVF4_LUT[15] * scale + + // Pack into array for indexed access + float scaled_lut[16] = { + lut0, lut1, lut2, lut3, lut4, lut5, lut6, lut7, + lut8, lut9, lut10, lut11, lut12, lut13, lut14, lut15 + }; + + int k_packed_base = k_base / 2; + + // Process 32 elements (16 packed bytes) with full unroll + #pragma unroll + for (int i = 0; i < 16; i += 4) { + // Load 4 packed bytes + uint8_t p0 = __ldg(B_col + (k_packed_base + i + 0) * N); + uint8_t p1 = __ldg(B_col + (k_packed_base + i + 1) * N); + uint8_t p2 = __ldg(B_col + (k_packed_base + i + 2) * N); + uint8_t p3 = __ldg(B_col + (k_packed_base + i + 3) * N); + + // Dequantize using pre-scaled LUT (no per-value multiply) + float b0 = scaled_lut[p0 & 0x0F]; + float b1 = scaled_lut[(p0 >> 4) & 0x0F]; + float b2 = scaled_lut[p1 & 0x0F]; + float b3 = scaled_lut[(p1 >> 4) & 0x0F]; + float b4 = scaled_lut[p2 & 0x0F]; + float b5 = scaled_lut[(p2 >> 4) & 0x0F]; + float b6 = scaled_lut[p3 & 0x0F]; + float b7 = scaled_lut[(p3 >> 4) & 0x0F]; + + // Load A values (L1 cache should hit well) + int a_idx = k_base + i * 2; + float a0 = __bfloat162float(A[a_idx + 0]); + float a1 = __bfloat162float(A[a_idx + 1]); + float a2 = __bfloat162float(A[a_idx + 2]); + float a3 = __bfloat162float(A[a_idx + 3]); + float a4 = __bfloat162float(A[a_idx + 4]); + float a5 = 
__bfloat162float(A[a_idx + 5]); + float a6 = __bfloat162float(A[a_idx + 6]); + float a7 = __bfloat162float(A[a_idx + 7]); + + // Accumulate with FMA + acc = fmaf(a0, b0, acc); + acc = fmaf(a1, b1, acc); + acc = fmaf(a2, b2, acc); + acc = fmaf(a3, b3, acc); + acc = fmaf(a4, b4, acc); + acc = fmaf(a5, b5, acc); + acc = fmaf(a6, b6, acc); + acc = fmaf(a7, b7, acc); + } + } + + // Handle remainder (if K is not multiple of SCALE_BLOCK) + if (K_remainder > 0) { + int sb = num_scale_blocks; + int k_base = sb * Config::SCALE_BLOCK; + + float scale = decode_ue4m3_scale(__ldg(S_col + sb * N)); + + for (int k = 0; k < K_remainder; k += 2) { + int k_packed = (k_base + k) / 2; + uint8_t packed = __ldg(B_col + k_packed * N); + + float b0 = NVF4_LUT[packed & 0x0F] * scale; + float b1 = NVF4_LUT[(packed >> 4) & 0x0F] * scale; + + float a0 = __bfloat162float(A[k_base + k]); + float a1 = (k + 1 < K_remainder) ? __bfloat162float(A[k_base + k + 1]) : 0.0f; + + acc = fmaf(a0, b0, acc); + acc = fmaf(a1, b1, acc); + } + } + + C[global_n] = __float2bfloat16(alpha * acc); +} + +/** + * Optimized kernel with 2 outputs per thread + * + * Key optimization: + * - Each thread computes 2 output columns + * - A vector loads shared between both columns + * - Higher arithmetic intensity, better ILP + */ +template +__global__ void gemv_nvf4_bf16_kernel_multi( + __nv_bfloat16 const* __restrict__ A, + uint8_t const* __restrict__ B_data, + uint8_t const* __restrict__ B_scale, + __nv_bfloat16* __restrict__ C, + int K, + int N, + float alpha +) { + const int tid = threadIdx.x; + const int block_n = blockIdx.x * Config::TILE_N * COLS_PER_THREAD; + const int global_n0 = block_n + tid; + const int global_n1 = global_n0 + Config::TILE_N; + + const bool valid0 = (global_n0 < N); + const bool valid1 = (global_n1 < N); + + if (!valid0 && !valid1) return; + + float acc0 = 0.0f; + float acc1 = 0.0f; + + const uint8_t* B_col0 = B_data + global_n0; + const uint8_t* B_col1 = B_data + global_n1; + const uint8_t* S_col0 = B_scale + global_n0; + const uint8_t* S_col1 = B_scale + global_n1; + + const int num_scale_blocks = K / Config::SCALE_BLOCK; + + // Main loop: process complete scale blocks + for (int sb = 0; sb < num_scale_blocks; ++sb) { + int k_base = sb * Config::SCALE_BLOCK; + + // Load scales for both columns + float scale0 = valid0 ? decode_ue4m3_scale(__ldg(S_col0 + sb * N)) : 0.0f; + float scale1 = valid1 ? 
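/* the second column can fall past N, so its scale degenerates to zero */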
decode_ue4m3_scale(__ldg(S_col1 + sb * N)) : 0.0f; + + int k_packed_base = k_base / 2; + + // Process 32 elements (16 packed bytes) with full unroll + #pragma unroll + for (int i = 0; i < 16; i += 4) { + // Load A values once (shared between both columns) + int a_idx = k_base + i * 2; + float a0 = __bfloat162float(A[a_idx + 0]); + float a1 = __bfloat162float(A[a_idx + 1]); + float a2 = __bfloat162float(A[a_idx + 2]); + float a3 = __bfloat162float(A[a_idx + 3]); + float a4 = __bfloat162float(A[a_idx + 4]); + float a5 = __bfloat162float(A[a_idx + 5]); + float a6 = __bfloat162float(A[a_idx + 6]); + float a7 = __bfloat162float(A[a_idx + 7]); + + // Process column 0 + if (valid0) { + uint8_t p0 = __ldg(B_col0 + (k_packed_base + i + 0) * N); + uint8_t p1 = __ldg(B_col0 + (k_packed_base + i + 1) * N); + uint8_t p2 = __ldg(B_col0 + (k_packed_base + i + 2) * N); + uint8_t p3 = __ldg(B_col0 + (k_packed_base + i + 3) * N); + + acc0 = fmaf(a0, NVF4_LUT[p0 & 0x0F] * scale0, acc0); + acc0 = fmaf(a1, NVF4_LUT[(p0 >> 4) & 0x0F] * scale0, acc0); + acc0 = fmaf(a2, NVF4_LUT[p1 & 0x0F] * scale0, acc0); + acc0 = fmaf(a3, NVF4_LUT[(p1 >> 4) & 0x0F] * scale0, acc0); + acc0 = fmaf(a4, NVF4_LUT[p2 & 0x0F] * scale0, acc0); + acc0 = fmaf(a5, NVF4_LUT[(p2 >> 4) & 0x0F] * scale0, acc0); + acc0 = fmaf(a6, NVF4_LUT[p3 & 0x0F] * scale0, acc0); + acc0 = fmaf(a7, NVF4_LUT[(p3 >> 4) & 0x0F] * scale0, acc0); + } + + // Process column 1 + if (valid1) { + uint8_t p0 = __ldg(B_col1 + (k_packed_base + i + 0) * N); + uint8_t p1 = __ldg(B_col1 + (k_packed_base + i + 1) * N); + uint8_t p2 = __ldg(B_col1 + (k_packed_base + i + 2) * N); + uint8_t p3 = __ldg(B_col1 + (k_packed_base + i + 3) * N); + + acc1 = fmaf(a0, NVF4_LUT[p0 & 0x0F] * scale1, acc1); + acc1 = fmaf(a1, NVF4_LUT[(p0 >> 4) & 0x0F] * scale1, acc1); + acc1 = fmaf(a2, NVF4_LUT[p1 & 0x0F] * scale1, acc1); + acc1 = fmaf(a3, NVF4_LUT[(p1 >> 4) & 0x0F] * scale1, acc1); + acc1 = fmaf(a4, NVF4_LUT[p2 & 0x0F] * scale1, acc1); + acc1 = fmaf(a5, NVF4_LUT[(p2 >> 4) & 0x0F] * scale1, acc1); + acc1 = fmaf(a6, NVF4_LUT[p3 & 0x0F] * scale1, acc1); + acc1 = fmaf(a7, NVF4_LUT[(p3 >> 4) & 0x0F] * scale1, acc1); + } + } + } + + // Store results + if (valid0) C[global_n0] = __float2bfloat16(alpha * acc0); + if (valid1) C[global_n1] = __float2bfloat16(alpha * acc1); +} + +// ============================================================================ +// Launch Functions +// ============================================================================ + +/** + * Launch NVF4 GEMV + * + * @param A Input vector [K] BF16 + * @param B_data Weight matrix [K/2, N] packed NVF4 + * @param B_scale Scale factors [K/32, N] UE4M3 + * @param C Output vector [N] BF16 + * @param K Inner dimension + * @param N Output dimension + * @param alpha Scaling factor (default 1.0) + * @param stream CUDA stream + */ +inline cudaError_t launch_gemv_nvf4_bf16( + const __nv_bfloat16* A, + const uint8_t* B_data, + const uint8_t* B_scale, + __nv_bfloat16* C, + int K, + int N, + float alpha = 1.0f, + cudaStream_t stream = nullptr +) { + using Config = GemvNvf4Config; + + dim3 block(Config::BLOCK_SIZE); + dim3 grid((N + Config::TILE_N - 1) / Config::TILE_N); + + // Use unrolled kernel for aligned K + if (K % Config::SCALE_BLOCK == 0 && K >= Config::SCALE_BLOCK) { + gemv_nvf4_bf16_kernel_unrolled<<>>( + A, B_data, B_scale, C, K, N, alpha + ); + } else { + gemv_nvf4_bf16_kernel<<>>( + A, B_data, B_scale, C, K, N, alpha + ); + } + + return cudaGetLastError(); +} + +// 
============================================================================ +// Quantization Kernel (BF16 -> NVF4) +// ============================================================================ + +/** + * Quantize BF16 matrix to NVF4 with block scaling + * + * Input: B[K, N] BF16 row-major + * Output: B_data[K/2, N] packed NVF4 + * B_scale[K/32, N] UE4M3 scale factors + */ +__global__ void quantize_bf16_to_nvf4_kernel( + __nv_bfloat16 const* __restrict__ input, // [K, N] row-major + uint8_t* __restrict__ output_data, // [K/2, N] packed NVF4 + uint8_t* __restrict__ output_scale, // [K/32, N] scale factors + int K, + int N +) { + const int n = blockIdx.x * blockDim.x + threadIdx.x; + const int scale_block = blockIdx.y; + + if (n >= N) return; + + const int SCALE_BLOCK = 32; + const int k_start = scale_block * SCALE_BLOCK; + const int k_end = min(k_start + SCALE_BLOCK, K); + + // Find max absolute value in block + float max_abs = 0.0f; + for (int k = k_start; k < k_end; ++k) { + float val = fabsf(__bfloat162float(input[k * N + n])); + max_abs = fmaxf(max_abs, val); + } + + // Compute scale factor (target range: [-6, 6] for NVF4) + const float NVF4_MAX = 6.0f; + float scale = (max_abs > 1e-8f) ? (max_abs / NVF4_MAX) : 1.0f; + float inv_scale = 1.0f / scale; + + // Encode scale as UE4M3 + // UE4M3: value = (1 + mantissa/8) * 2^(exponent - 7) + // We need to find exp and mant such that scale ~= (1 + mant/8) * 2^(exp-7) + + // First, find exponent by getting floor(log2(scale)) and shift to [1,2) range + int exp_raw = 0; + float normalized = scale; + + if (normalized >= 2.0f) { + while (normalized >= 2.0f && exp_raw < 8) { + normalized *= 0.5f; + exp_raw++; + } + } else if (normalized < 1.0f && normalized > 1e-8f) { + while (normalized < 1.0f && exp_raw > -7) { + normalized *= 2.0f; + exp_raw--; + } + } + + // Now normalized is in [1.0, 2.0), compute mantissa + // mantissa = (normalized - 1) * 8, rounded to nearest integer + int mant = __float2int_rn((normalized - 1.0f) * 8.0f); + mant = max(0, min(7, mant)); + + // Compute biased exponent + int exp_biased = exp_raw + 7; + exp_biased = max(0, min(15, exp_biased)); + + uint8_t scale_encoded = ((exp_biased & 0xF) << 3) | (mant & 0x7); + output_scale[scale_block * N + n] = scale_encoded; + + // Recompute actual encoded scale for accurate quantization + float encoded_scale = (1.0f + mant / 8.0f) * ldexpf(1.0f, exp_biased - 7); + inv_scale = 1.0f / encoded_scale; + + // Quantize values to NVF4 + for (int k = k_start; k < k_end; k += 2) { + float v0 = __bfloat162float(input[k * N + n]) * inv_scale; + float v1 = (k + 1 < k_end) ? __bfloat162float(input[(k + 1) * N + n]) * inv_scale : 0.0f; + + // Quantize to NVF4 (nearest value in lookup table) + auto quantize_nvf4 = [](float val) -> uint8_t { + uint8_t sign = (val < 0) ? 
0x8 : 0x0; + val = fabsf(val); + if (val < 0.25f) return sign | 0; // 0 + if (val < 0.75f) return sign | 1; // 0.5 + if (val < 1.25f) return sign | 2; // 1.0 + if (val < 1.75f) return sign | 3; // 1.5 + if (val < 2.5f) return sign | 4; // 2.0 + if (val < 3.5f) return sign | 5; // 3.0 + if (val < 5.0f) return sign | 6; // 4.0 + return sign | 7; // 6.0 + }; + + uint8_t q0 = quantize_nvf4(v0); + uint8_t q1 = quantize_nvf4(v1); + + // Pack: low nibble = first element, high nibble = second + int k_packed = k / 2; + output_data[k_packed * N + n] = (q1 << 4) | (q0 & 0x0F); + } +} + +/** + * Launch quantization kernel + */ +inline cudaError_t quantize_bf16_to_nvf4( + const __nv_bfloat16* input, + uint8_t* output_data, + uint8_t* output_scale, + int K, + int N, + cudaStream_t stream = nullptr +) { + const int SCALE_BLOCK = 32; + int num_scale_blocks = (K + SCALE_BLOCK - 1) / SCALE_BLOCK; + + dim3 block(256); + dim3 grid((N + 255) / 256, num_scale_blocks); + + quantize_bf16_to_nvf4_kernel<<>>( + input, output_data, output_scale, K, N + ); + + return cudaGetLastError(); +} + +// ============================================================================ +// High-Level API +// ============================================================================ + +/** + * Check if NVF4 GEMV is available (SM120+) + */ +inline bool is_available() { + int device_id = 0; + cudaGetDevice(&device_id); + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device_id); + return (props.major == 12); // SM120/SM121 +} + +} // namespace gemv_nvf4 +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/gemv/test_gemv.cu b/native/ops/gemv/test_gemv.cu new file mode 100644 index 0000000..ef73c8e --- /dev/null +++ b/native/ops/gemv/test_gemv.cu @@ -0,0 +1,433 @@ +/** + * GEMV Correctness Test + * + * Verifies CUTLASS GEMV against CPU reference implementation. + * No cuBLASLt dependency. + * + * Build: + * nvcc -std=c++17 -O3 -arch=sm_86 test_gemv.cu -o test_gemv + * + * Usage: + * ./test_gemv + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "gemv_cutlass.cuh" + +// ============================================================================ +// CPU Reference Implementation +// ============================================================================ + +void gemv_cpu_reference( + const float* A, // [1, K] + const float* B, // [K, N] + float* C, // [1, N] + int K, int N, + float alpha, float beta +) { + for (int n = 0; n < N; ++n) { + float acc = 0.0f; + for (int k = 0; k < K; ++k) { + acc += A[k] * B[k * N + n]; + } + if (beta != 0.0f) { + C[n] = alpha * acc + beta * C[n]; + } else { + C[n] = alpha * acc; + } + } +} + +// ============================================================================ +// Test Functions +// ============================================================================ + +bool test_gemv_bf16(int K, int N, float tolerance = 0.01f) { + printf("Testing BF16 GEMV: K=%d, N=%d ... 
", K, N); + + // Host allocations + std::vector h_A(K); + std::vector h_B(K * N); + std::vector h_C_ref(N, 0.0f); + std::vector<__nv_bfloat16> h_A_bf16(K); + std::vector<__nv_bfloat16> h_B_bf16(K * N); + std::vector<__nv_bfloat16> h_C_bf16(N); + + // Initialize with random data + srand(42); + for (int i = 0; i < K; ++i) { + h_A[i] = (static_cast(rand()) / RAND_MAX - 0.5f) * 0.2f; + h_A_bf16[i] = __float2bfloat16(h_A[i]); + } + for (int i = 0; i < K * N; ++i) { + h_B[i] = (static_cast(rand()) / RAND_MAX - 0.5f) * 0.2f; + h_B_bf16[i] = __float2bfloat16(h_B[i]); + } + + // CPU reference (using BF16-rounded values for fair comparison) + std::vector h_A_rounded(K); + std::vector h_B_rounded(K * N); + for (int i = 0; i < K; ++i) { + h_A_rounded[i] = __bfloat162float(h_A_bf16[i]); + } + for (int i = 0; i < K * N; ++i) { + h_B_rounded[i] = __bfloat162float(h_B_bf16[i]); + } + gemv_cpu_reference(h_A_rounded.data(), h_B_rounded.data(), h_C_ref.data(), K, N, 1.0f, 0.0f); + + // Device allocations + __nv_bfloat16 *d_A, *d_B, *d_C; + cudaMalloc(&d_A, K * sizeof(__nv_bfloat16)); + cudaMalloc(&d_B, K * N * sizeof(__nv_bfloat16)); + cudaMalloc(&d_C, N * sizeof(__nv_bfloat16)); + + cudaMemcpy(d_A, h_A_bf16.data(), K * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice); + cudaMemcpy(d_B, h_B_bf16.data(), K * N * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice); + cudaMemset(d_C, 0, N * sizeof(__nv_bfloat16)); + + // Run GPU kernel + cudaError_t err = pygpukit::ops::gemv::launch_gemv_bf16(d_A, d_B, d_C, K, N); + if (err != cudaSuccess) { + printf("FAILED (kernel launch error: %s)\n", cudaGetErrorString(err)); + cudaFree(d_A); + cudaFree(d_B); + cudaFree(d_C); + return false; + } + cudaDeviceSynchronize(); + + // Copy back results + cudaMemcpy(h_C_bf16.data(), d_C, N * sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost); + + // Compare results + float max_err = 0.0f; + float max_rel_err = 0.0f; + int max_err_idx = 0; + for (int i = 0; i < N; ++i) { + float gpu_val = __bfloat162float(h_C_bf16[i]); + float ref_val = h_C_ref[i]; + float err = std::abs(gpu_val - ref_val); + float rel_err = err / (std::abs(ref_val) + 1e-6f); + if (err > max_err) { + max_err = err; + max_err_idx = i; + } + max_rel_err = std::max(max_rel_err, rel_err); + } + + // Cleanup + cudaFree(d_A); + cudaFree(d_B); + cudaFree(d_C); + + if (max_rel_err < tolerance) { + printf("PASS (max_rel_err=%.6f at idx=%d)\n", max_rel_err, max_err_idx); + return true; + } else { + printf("FAILED (max_rel_err=%.6f at idx=%d, ref=%.6f, gpu=%.6f)\n", + max_rel_err, max_err_idx, h_C_ref[max_err_idx], + __bfloat162float(h_C_bf16[max_err_idx])); + return false; + } +} + +bool test_gemv_fp16(int K, int N, float tolerance = 0.005f) { + printf("Testing FP16 GEMV: K=%d, N=%d ... 
", K, N); + + // Host allocations + std::vector h_A(K); + std::vector h_B(K * N); + std::vector h_C_ref(N, 0.0f); + std::vector<__half> h_A_fp16(K); + std::vector<__half> h_B_fp16(K * N); + std::vector<__half> h_C_fp16(N); + + // Initialize with random data + srand(42); + for (int i = 0; i < K; ++i) { + h_A[i] = (static_cast(rand()) / RAND_MAX - 0.5f) * 0.2f; + h_A_fp16[i] = __float2half(h_A[i]); + } + for (int i = 0; i < K * N; ++i) { + h_B[i] = (static_cast(rand()) / RAND_MAX - 0.5f) * 0.2f; + h_B_fp16[i] = __float2half(h_B[i]); + } + + // CPU reference (using FP16-rounded values) + std::vector h_A_rounded(K); + std::vector h_B_rounded(K * N); + for (int i = 0; i < K; ++i) { + h_A_rounded[i] = __half2float(h_A_fp16[i]); + } + for (int i = 0; i < K * N; ++i) { + h_B_rounded[i] = __half2float(h_B_fp16[i]); + } + gemv_cpu_reference(h_A_rounded.data(), h_B_rounded.data(), h_C_ref.data(), K, N, 1.0f, 0.0f); + + // Device allocations + __half *d_A, *d_B, *d_C; + cudaMalloc(&d_A, K * sizeof(__half)); + cudaMalloc(&d_B, K * N * sizeof(__half)); + cudaMalloc(&d_C, N * sizeof(__half)); + + cudaMemcpy(d_A, h_A_fp16.data(), K * sizeof(__half), cudaMemcpyHostToDevice); + cudaMemcpy(d_B, h_B_fp16.data(), K * N * sizeof(__half), cudaMemcpyHostToDevice); + cudaMemset(d_C, 0, N * sizeof(__half)); + + // Run GPU kernel + cudaError_t err = pygpukit::ops::gemv::launch_gemv_fp16(d_A, d_B, d_C, K, N); + if (err != cudaSuccess) { + printf("FAILED (kernel launch error: %s)\n", cudaGetErrorString(err)); + cudaFree(d_A); + cudaFree(d_B); + cudaFree(d_C); + return false; + } + cudaDeviceSynchronize(); + + // Copy back results + cudaMemcpy(h_C_fp16.data(), d_C, N * sizeof(__half), cudaMemcpyDeviceToHost); + + // Compare results + float max_rel_err = 0.0f; + int max_err_idx = 0; + for (int i = 0; i < N; ++i) { + float gpu_val = __half2float(h_C_fp16[i]); + float ref_val = h_C_ref[i]; + float err = std::abs(gpu_val - ref_val); + float rel_err = err / (std::abs(ref_val) + 1e-6f); + if (rel_err > max_rel_err) { + max_rel_err = rel_err; + max_err_idx = i; + } + } + + // Cleanup + cudaFree(d_A); + cudaFree(d_B); + cudaFree(d_C); + + if (max_rel_err < tolerance) { + printf("PASS (max_rel_err=%.6f)\n", max_rel_err); + return true; + } else { + printf("FAILED (max_rel_err=%.6f)\n", max_rel_err); + return false; + } +} + +bool test_gemv_fp32(int K, int N, float tolerance = 0.002f) { + printf("Testing FP32 GEMV: K=%d, N=%d ... 
", K, N); + + // Host allocations + std::vector h_A(K); + std::vector h_B(K * N); + std::vector h_C_ref(N, 0.0f); + std::vector h_C_gpu(N, 0.0f); + + // Initialize with random data + srand(42); + for (int i = 0; i < K; ++i) { + h_A[i] = (static_cast(rand()) / RAND_MAX - 0.5f) * 0.2f; + } + for (int i = 0; i < K * N; ++i) { + h_B[i] = (static_cast(rand()) / RAND_MAX - 0.5f) * 0.2f; + } + + // CPU reference + gemv_cpu_reference(h_A.data(), h_B.data(), h_C_ref.data(), K, N, 1.0f, 0.0f); + + // Device allocations + float *d_A, *d_B, *d_C; + cudaMalloc(&d_A, K * sizeof(float)); + cudaMalloc(&d_B, K * N * sizeof(float)); + cudaMalloc(&d_C, N * sizeof(float)); + + cudaMemcpy(d_A, h_A.data(), K * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(d_B, h_B.data(), K * N * sizeof(float), cudaMemcpyHostToDevice); + cudaMemset(d_C, 0, N * sizeof(float)); + + // Run GPU kernel + cudaError_t err = pygpukit::ops::gemv::launch_gemv_fp32(d_A, d_B, d_C, K, N); + if (err != cudaSuccess) { + printf("FAILED (kernel launch error: %s)\n", cudaGetErrorString(err)); + cudaFree(d_A); + cudaFree(d_B); + cudaFree(d_C); + return false; + } + cudaDeviceSynchronize(); + + // Copy back results + cudaMemcpy(h_C_gpu.data(), d_C, N * sizeof(float), cudaMemcpyDeviceToHost); + + // Compare results + float max_rel_err = 0.0f; + for (int i = 0; i < N; ++i) { + float err = std::abs(h_C_gpu[i] - h_C_ref[i]); + float rel_err = err / (std::abs(h_C_ref[i]) + 1e-6f); + max_rel_err = std::max(max_rel_err, rel_err); + } + + // Cleanup + cudaFree(d_A); + cudaFree(d_B); + cudaFree(d_C); + + if (max_rel_err < tolerance) { + printf("PASS (max_rel_err=%.6f)\n", max_rel_err); + return true; + } else { + printf("FAILED (max_rel_err=%.6f)\n", max_rel_err); + return false; + } +} + +bool test_gemv_batched_bf16(int batch, int K, int N, float tolerance = 0.01f) { + printf("Testing Batched BF16 GEMV: batch=%d, K=%d, N=%d ... 
", batch, K, N); + + // Host allocations + std::vector h_A(batch * K); + std::vector h_B(K * N); + std::vector h_C_ref(batch * N, 0.0f); + std::vector<__nv_bfloat16> h_A_bf16(batch * K); + std::vector<__nv_bfloat16> h_B_bf16(K * N); + std::vector<__nv_bfloat16> h_C_bf16(batch * N); + + // Initialize + srand(42); + for (int i = 0; i < batch * K; ++i) { + h_A[i] = (static_cast(rand()) / RAND_MAX - 0.5f) * 0.2f; + h_A_bf16[i] = __float2bfloat16(h_A[i]); + } + for (int i = 0; i < K * N; ++i) { + h_B[i] = (static_cast(rand()) / RAND_MAX - 0.5f) * 0.2f; + h_B_bf16[i] = __float2bfloat16(h_B[i]); + } + + // CPU reference (per batch) + for (int b = 0; b < batch; ++b) { + std::vector h_A_rounded(K); + std::vector h_B_rounded(K * N); + for (int i = 0; i < K; ++i) { + h_A_rounded[i] = __bfloat162float(h_A_bf16[b * K + i]); + } + for (int i = 0; i < K * N; ++i) { + h_B_rounded[i] = __bfloat162float(h_B_bf16[i]); + } + gemv_cpu_reference(h_A_rounded.data(), h_B_rounded.data(), + h_C_ref.data() + b * N, K, N, 1.0f, 0.0f); + } + + // Device allocations + __nv_bfloat16 *d_A, *d_B, *d_C; + cudaMalloc(&d_A, batch * K * sizeof(__nv_bfloat16)); + cudaMalloc(&d_B, K * N * sizeof(__nv_bfloat16)); + cudaMalloc(&d_C, batch * N * sizeof(__nv_bfloat16)); + + cudaMemcpy(d_A, h_A_bf16.data(), batch * K * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice); + cudaMemcpy(d_B, h_B_bf16.data(), K * N * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice); + cudaMemset(d_C, 0, batch * N * sizeof(__nv_bfloat16)); + + // Run GPU kernel + cudaError_t err = pygpukit::ops::gemv::launch_gemv_bf16_batched( + d_A, d_B, d_C, K, N, batch); + if (err != cudaSuccess) { + printf("FAILED (kernel launch error: %s)\n", cudaGetErrorString(err)); + cudaFree(d_A); + cudaFree(d_B); + cudaFree(d_C); + return false; + } + cudaDeviceSynchronize(); + + // Copy back results + cudaMemcpy(h_C_bf16.data(), d_C, batch * N * sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost); + + // Compare results + float max_rel_err = 0.0f; + for (int i = 0; i < batch * N; ++i) { + float gpu_val = __bfloat162float(h_C_bf16[i]); + float ref_val = h_C_ref[i]; + float err = std::abs(gpu_val - ref_val); + float rel_err = err / (std::abs(ref_val) + 1e-6f); + max_rel_err = std::max(max_rel_err, rel_err); + } + + // Cleanup + cudaFree(d_A); + cudaFree(d_B); + cudaFree(d_C); + + if (max_rel_err < tolerance) { + printf("PASS (max_rel_err=%.6f)\n", max_rel_err); + return true; + } else { + printf("FAILED (max_rel_err=%.6f)\n", max_rel_err); + return false; + } +} + +// ============================================================================ +// Main +// ============================================================================ + +int main() { + // Print device info + int device; + cudaGetDevice(&device); + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device); + printf("Device: %s (SM %d%d)\n", props.name, props.major, props.minor); + printf("\n"); + + printf("=== GEMV Correctness Tests ===\n\n"); + + int passed = 0; + int failed = 0; + + // BF16 tests + printf("--- BF16 GEMV ---\n"); + if (test_gemv_bf16(256, 256)) passed++; else failed++; + if (test_gemv_bf16(512, 512)) passed++; else failed++; + if (test_gemv_bf16(1024, 1024)) passed++; else failed++; + if (test_gemv_bf16(4096, 4096)) passed++; else failed++; + if (test_gemv_bf16(4096, 11008)) passed++; else failed++; // LLaMA MLP + if (test_gemv_bf16(8192, 28672)) passed++; else failed++; // LLaMA-70B MLP + printf("\n"); + + // FP16 tests + printf("--- FP16 GEMV ---\n"); + if (test_gemv_fp16(256, 256)) passed++; else 
failed++; + if (test_gemv_fp16(1024, 1024)) passed++; else failed++; + if (test_gemv_fp16(4096, 4096)) passed++; else failed++; + printf("\n"); + + // FP32 tests + printf("--- FP32 GEMV ---\n"); + if (test_gemv_fp32(256, 256)) passed++; else failed++; + if (test_gemv_fp32(1024, 1024)) passed++; else failed++; + if (test_gemv_fp32(4096, 4096)) passed++; else failed++; + printf("\n"); + + // Batched BF16 tests + printf("--- Batched BF16 GEMV ---\n"); + if (test_gemv_batched_bf16(4, 1024, 1024)) passed++; else failed++; + if (test_gemv_batched_bf16(8, 4096, 4096)) passed++; else failed++; + if (test_gemv_batched_bf16(16, 4096, 11008)) passed++; else failed++; + printf("\n"); + + // Summary + printf("=== Summary ===\n"); + printf("Passed: %d\n", passed); + printf("Failed: %d\n", failed); + + return failed > 0 ? 1 : 0; +} diff --git a/native/ops/matmul/aligned_copy_sm120.cuh b/native/ops/matmul/aligned_copy_sm120.cuh new file mode 100644 index 0000000..4dbfaef --- /dev/null +++ b/native/ops/matmul/aligned_copy_sm120.cuh @@ -0,0 +1,269 @@ +/** + * Aligned Copy Operations for SM120 FP8 GEMM + * + * Workaround for CUTLASS Issue #2902: + * - partition_S() drops alignment from 1024 to 8 bytes + * - SM75_U32x4_LDSM_N requires 16-byte alignment + * + * This file provides: + * 1. Inline PTX helpers for alignment-safe shared memory loads + * 2. A macro to patch CUTLASS's LDSM operations post-include + * + * Usage: + * // Include this AFTER CUTLASS headers + * #include + * #include "aligned_copy_sm120.cuh" + * + * // The CUTLASS kernel will use patched copy operations + * // if PYGPUKIT_PATCH_CUTLASS_LDSM_POST is defined + */ +#pragma once + +#include +#include + +// ============================================================================ +// Core PTX Helpers for Shared Memory Operations +// ============================================================================ + +namespace pygpukit { +namespace ops { +namespace aligned_copy { + +/** + * Convert shared memory pointer to generic address space (32-bit for PTX) + */ +__device__ __forceinline__ +uint32_t smem_ptr_to_u32(const void* ptr) { +#if defined(__CUDA_ARCH__) + return static_cast(__cvta_generic_to_shared(ptr)); +#else + return 0; +#endif +} + +/** + * Load 4x u32 (16 bytes) from shared memory with alignment check. + * + * IMPORTANT: ldmatrix.sync requires ALL threads in the warp to participate. + * This function assumes it's called by the full warp (CUTLASS pattern). + * For single-thread usage, use ld_shared_u32x4_scalar instead. 
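 *
 * Illustrative call pattern (a sketch, not taken from the kernels below; the
 * `tile` buffer and `byte_offset` are hypothetical, and the whole warp must
 * reach the call when the address is 16-byte aligned):
 *
 *   __shared__ __align__(16) uint8_t tile[8192];
 *   uint32_t r0, r1, r2, r3;
 *   uint32_t addr = pygpukit::ops::aligned_copy::smem_ptr_to_u32(tile + byte_offset);
 *   pygpukit::ops::aligned_copy::ld_shared_u32x4_safe(addr, r0, r1, r2, r3);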
+ * + * Behavior: + * - 16-byte aligned: uses ldmatrix.sync (fast, requires full warp) + * - Misaligned: falls back to scalar loads (slower but always safe) + */ +__device__ __forceinline__ +void ld_shared_u32x4_safe( + uint32_t smem_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + if ((smem_addr & 0xF) == 0) { + // 16-byte aligned: use ldmatrix (fast path) + // NOTE: ldmatrix.sync requires all warp threads to execute this + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_addr) + ); + } else { + // Misaligned: use scalar loads (slow but correct) + asm volatile( + "ld.shared.u32 %0, [%4];\n" + "ld.shared.u32 %1, [%5];\n" + "ld.shared.u32 %2, [%6];\n" + "ld.shared.u32 %3, [%7];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_addr), + "r"(smem_addr + 4u), + "r"(smem_addr + 8u), + "r"(smem_addr + 12u) + ); + } +#endif +} + +/** + * Load 4x u32 with forced alignment (trust caller) + */ +__device__ __forceinline__ +void ld_shared_u32x4_trusted( + uint32_t smem_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_addr) + ); +#endif +} + +/** + * Load 4x u32 using scalar loads only (always safe) + */ +__device__ __forceinline__ +void ld_shared_u32x4_scalar( + uint32_t smem_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) +{ +#if defined(__CUDA_ARCH__) + asm volatile( + "ld.shared.u32 %0, [%4];\n" + "ld.shared.u32 %1, [%5];\n" + "ld.shared.u32 %2, [%6];\n" + "ld.shared.u32 %3, [%7];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_addr), + "r"(smem_addr + 4u), + "r"(smem_addr + 8u), + "r"(smem_addr + 12u) + ); +#endif +} + +/** + * Load 4x u32 with transpose and alignment check + */ +__device__ __forceinline__ +void ld_shared_u32x4_trans_safe( + uint32_t smem_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + if ((smem_addr & 0xF) == 0) { + asm volatile( + "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_addr) + ); + } else { + // Scalar fallback (no transpose - caller must handle) + asm volatile( + "ld.shared.u32 %0, [%4];\n" + "ld.shared.u32 %1, [%5];\n" + "ld.shared.u32 %2, [%6];\n" + "ld.shared.u32 %3, [%7];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_addr), + "r"(smem_addr + 4u), + "r"(smem_addr + 8u), + "r"(smem_addr + 12u) + ); + } +#endif +} + +/** + * Load 2x u32 (8 bytes) with alignment check + */ +__device__ __forceinline__ +void ld_shared_u32x2_safe( + uint32_t smem_addr, + uint32_t& dst0, uint32_t& dst1) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + if ((smem_addr & 0x7) == 0) { + asm volatile( + "ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_addr) + ); + } else { + asm volatile( + "ld.shared.u32 %0, [%2];\n" + "ld.shared.u32 %1, [%3];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_addr), + "r"(smem_addr + 4u) + ); + } +#endif +} + +/** + * Load 1x u32 with ldmatrix + */ +__device__ __forceinline__ +void ld_shared_u32x1(uint32_t smem_addr, uint32_t& dst0) +{ +#if 
defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + asm volatile( + "ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(dst0) + : "r"(smem_addr) + ); +#endif +} + +} // namespace aligned_copy +} // namespace ops +} // namespace pygpukit + +// ============================================================================ +// CUTLASS Integration Macros +// ============================================================================ + +/** + * Macro to wrap a shared memory load with alignment-safe version. + * Use this in custom kernels or modified CUTLASS mainloops. + * + * Example: + * uint32_t r0, r1, r2, r3; + * PYGPUKIT_SAFE_LDSM_X4(smem_ptr, r0, r1, r2, r3); + */ +#define PYGPUKIT_SAFE_LDSM_X4(smem_ptr, r0, r1, r2, r3) \ + do { \ + uint32_t _addr = pygpukit::ops::aligned_copy::smem_ptr_to_u32(smem_ptr); \ + pygpukit::ops::aligned_copy::ld_shared_u32x4_safe(_addr, r0, r1, r2, r3); \ + } while(0) + +#define PYGPUKIT_SAFE_LDSM_X4_TRANS(smem_ptr, r0, r1, r2, r3) \ + do { \ + uint32_t _addr = pygpukit::ops::aligned_copy::smem_ptr_to_u32(smem_ptr); \ + pygpukit::ops::aligned_copy::ld_shared_u32x4_trans_safe(_addr, r0, r1, r2, r3); \ + } while(0) + +#define PYGPUKIT_SAFE_LDSM_X2(smem_ptr, r0, r1) \ + do { \ + uint32_t _addr = pygpukit::ops::aligned_copy::smem_ptr_to_u32(smem_ptr); \ + pygpukit::ops::aligned_copy::ld_shared_u32x2_safe(_addr, r0, r1); \ + } while(0) + +// ============================================================================ +// Post-Include Patch for CUTLASS SM75 LDSM Operations +// ============================================================================ +// +// IMPORTANT: Include this AFTER cute/arch/copy_sm75.hpp +// +// This redefines the copy() function for SM75 LDSM structs using +// our alignment-safe implementations. +// ============================================================================ + +#if defined(PYGPUKIT_PATCH_CUTLASS_LDSM_POST) && defined(CUTE_ARCH_COPY_SM75_HPP) + +// Ensure the original structs exist +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + +namespace cute { + +// Override SM75_U32x4_LDSM_N::copy with our safe version +// Note: This uses ADL to find our implementation +struct SM75_U32x4_LDSM_N_Safe : SM75_U32x4_LDSM_N { + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + uint32_t addr = pygpukit::ops::aligned_copy::smem_ptr_to_u32(&smem_src); + pygpukit::ops::aligned_copy::ld_shared_u32x4_safe(addr, dst0, dst1, dst2, dst3); +#endif + } +}; + +} // namespace cute + +#endif // CUTE_ARCH_LDSM_SM75_ACTIVATED +#endif // PYGPUKIT_PATCH_CUTLASS_LDSM_POST && CUTE_ARCH_COPY_SM75_HPP diff --git a/native/ops/matmul/build_fp8_test.bat b/native/ops/matmul/build_fp8_test.bat new file mode 100644 index 0000000..4add1ea --- /dev/null +++ b/native/ops/matmul/build_fp8_test.bat @@ -0,0 +1,46 @@ +@echo off +REM Build FP8 GEMM test with CUTLASS alignment patch +REM This tests if the alignment fix enables FP8 to work on SM120 + +set SCRIPT_DIR=%~dp0 +cd /d %SCRIPT_DIR% + +call "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" + +set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1 +set CUTLASS_PATH=%SCRIPT_DIR%..\..\..\third_party\cutlass\include +set CUTLASS_TOOLS_PATH=%SCRIPT_DIR%..\..\..\third_party\cutlass\tools\util\include +set PATH=%CUDA_PATH%\bin;%PATH% + +echo. 
+echo Current directory: %CD% +echo CUTLASS path: %CUTLASS_PATH% +echo CUTLASS tools path: %CUTLASS_TOOLS_PATH% +echo. +echo Building test_fp8_patched.cu for SM120a (architecture-specific features)... +echo. + +REM Use sm_120a to enable __CUDA_ARCH_FEAT_SM120_ALL macro +REM This is required for CUTLASS kernel selection (Issue #2902 workaround) +REM Add -DPYGPUKIT_DEBUG_LDSM to enable printf debugging in LDSM operations +nvcc -arch=sm_120a -std=c++17 -O3 ^ + -I"%CUTLASS_PATH%" ^ + -I"%CUTLASS_TOOLS_PATH%" ^ + -DCUTLASS_ARCH_MMA_SM120_SUPPORTED ^ + -DPYGPUKIT_DEBUG_LDSM ^ + --expt-relaxed-constexpr ^ + -Xcompiler "/Zc:preprocessor" ^ + -o test_fp8_patched.exe test_fp8_patched.cu + +if errorlevel 1 ( + echo. + echo Build failed! + exit /b 1 +) + +echo. +echo Build succeeded! +echo. +echo Running test... +echo. +test_fp8_patched.exe diff --git a/native/ops/matmul/matmul.cu b/native/ops/matmul/matmul.cu index 268a398..0d46194 100644 --- a/native/ops/matmul/matmul.cu +++ b/native/ops/matmul/matmul.cu @@ -16,6 +16,7 @@ #include "../matmul_f16_bf16_tc.cuh" #include "../matmul_f16_bf16_tc_generic.cuh" #include "../matmul_cublaslt.cuh" +#include "../matmul_cutlass.cuh" #include #include @@ -78,19 +79,19 @@ void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c) { // Only check native TensorCore settings if CUTLASS is disabled if (!cutlass_enabled) { + sm_version = get_sm_version(); const char* tf32_env = std::getenv("PYGPUKIT_ALLOW_TF32"); const char* fp16_tc_env = std::getenv("PYGPUKIT_ALLOW_FP16_TC"); - if ((tf32_env && (tf32_env[0] == '1' || tf32_env[0] == 'y' || tf32_env[0] == 'Y')) || - (fp16_tc_env && (fp16_tc_env[0] == '1' || fp16_tc_env[0] == 'y' || fp16_tc_env[0] == 'Y'))) { - sm_version = get_sm_version(); - } + // On SM 120+ where CUTLASS doesn't work, automatically enable TF32 TensorCore + // This provides good performance fallback for Blackwell GeForce (RTX 5090) + bool auto_tf32 = (sm_version >= 120); - if (tf32_env && (tf32_env[0] == '1' || tf32_env[0] == 'y' || tf32_env[0] == 'Y')) { + if (auto_tf32 || (tf32_env && (tf32_env[0] == '1' || tf32_env[0] == 'y' || tf32_env[0] == 'Y'))) { tf32_enabled = (sm_version >= MIN_SM_VERSION); } - if (fp16_tc_env && (fp16_tc_env[0] == '1' || fp16_tc_env[0] == 'y' || fp16_tc_env[0] == 'Y')) { + if ((fp16_tc_env && (fp16_tc_env[0] == '1' || fp16_tc_env[0] == 'y' || fp16_tc_env[0] == 'Y'))) { fp16_tc_enabled = (sm_version >= MIN_SM_VERSION); } } @@ -626,5 +627,38 @@ GPUArray linear_bias_gelu(const GPUArray& input, const GPUArray& weight, const G return output; } +// ============================================================================ +// Batched GEMM Implementation +// ============================================================================ + +void batched_matmul_fp32(const GPUArray& A, const GPUArray& B, GPUArray& C, + int M, int N, int K, int batch_count, + int64_t strideA, int64_t strideB, int64_t strideC) { + // Validate inputs + if (A.dtype() != DataType::Float32 || B.dtype() != DataType::Float32 || C.dtype() != DataType::Float32) { + throw std::runtime_error("batched_matmul_fp32: all inputs must be float32"); + } + +#if PYGPUKIT_HAS_CUTLASS + // Use CUTLASS batched GEMM + cudaError_t err = cutlass_gemm::gemm_batched_fp32( + static_cast(A.data()), + static_cast(B.data()), + static_cast(C.data()), + M, N, K, + batch_count, + strideA, strideB, strideC, + 1.0f, 0.0f, // alpha, beta + internal::get_capture_stream() + ); + if (err != cudaSuccess) { + throw std::runtime_error("batched_matmul_fp32: CUTLASS kernel failed"); + } + 
sync_and_check("batched_matmul_fp32 CUTLASS kernel failed"); +#else + throw std::runtime_error("batched_matmul_fp32: CUTLASS not available"); +#endif +} + } // namespace ops } // namespace pygpukit diff --git a/native/ops/matmul/matmul_fp8_fp32_sm120.cu b/native/ops/matmul/matmul_fp8_fp32_sm120.cu new file mode 100644 index 0000000..782bfb0 --- /dev/null +++ b/native/ops/matmul/matmul_fp8_fp32_sm120.cu @@ -0,0 +1,501 @@ +/** + * FP8 GEMM implementation for SM120 (Blackwell GeForce) + * + * Path: + * 1. FP32 input + * 2. FP8 quantization (A scale, B scale separate) + * 3. FP8 CUTLASS GEMM + * 4. BF16 accumulate + * 5. FP32 output (if needed) + * + * Implementation based on CUTLASS example 87a: + * "87a_blackwell_geforce_fp8_bf16_gemm_blockwise" + * + * IMPORTANT: This is the ONLY backend for SM120. No cuBLAS fallback. + * + * WORKAROUND for CUTLASS bug #2902: + * - partition_S() drops alignment from 1024 to 8 bytes + * - SM75_U32x4_LDSM_N requires 16-byte alignment + * - We patch the LDSM copy operations to handle misalignment + * - Tracking issue: https://github.com/NVIDIA/cutlass/issues/2902 + * - Local issue: https://github.com/m96-chan/PyGPUkit/issues/107 + */ + +#include +#include +#include +#include + +// Enable FP8 SM120 with alignment patch +#define PYGPUKIT_ENABLE_FP8_SM120 + +// Only compile for SM120+ AND when explicitly enabled +#if (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) && defined(PYGPUKIT_ENABLE_FP8_SM120) + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/blockwise_scale_layout.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/device_memory.h" + +// ============================================================================ +// ALIGNMENT PATCH: Include AFTER CUTLASS headers +// Provides alignment-safe LDSM operations for Issue #2902 workaround +// ============================================================================ +#define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1 +#include "aligned_copy_sm120.cuh" + +using namespace cute; + +namespace pygpukit { +namespace ops { +namespace fp8_gemm_sm120 { + +// ============================================================================ +// GEMM Configuration: FP8 E4M3 x FP8 E4M3 -> BF16 with blockwise scaling +// Based on CUTLASS example 87a_blackwell_geforce_fp8_bf16_gemm_blockwise +// Using OpClassTensorOp for SM120 GeForce (NOT OpClassBlockScaledTensorOp) +// ============================================================================ + +// A matrix: FP8 E4M3, RowMajor +using ElementA = cutlass::float_e4m3_t; +using LayoutATag = cutlass::layout::RowMajor; +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + +// B matrix: FP8 E4M3, ColumnMajor +using ElementB = cutlass::float_e4m3_t; +using LayoutBTag = cutlass::layout::ColumnMajor; +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + +// Output: BF16 +using ElementC = cutlass::bfloat16_t; +using ElementD = cutlass::bfloat16_t; +using LayoutCTag = cutlass::layout::RowMajor; +using LayoutDTag = cutlass::layout::RowMajor; +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentD = AlignmentC; + +// Accumulator type +using ElementAccumulator = float; +using ElementCompute = float; + +// SM120 GeForce 
architecture with TensorOp +using ArchTag = cutlass::arch::Sm120; +using OperatorClass = cutlass::arch::OpClassTensorOp; + +// MMA and Cluster Tile Shapes +using MmaTileShape_MNK = Shape<_128, _128, _128>; +using ClusterShape_MNK = Shape<_1, _1, _1>; // GeForce: no cluster support + +// Scale configuration (trivial blockwise scaling from example 87a) +using ScaleConfig = decltype(cutlass::detail::sm120_trivial_blockwise_scale_config(MmaTileShape_MNK{})); +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + +// Epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto +>::CollectiveOp; + +// Mainloop with scale factor layouts +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto +>::CollectiveOp; + +// GEMM Kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + void // Default CLC scheduler +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Stride and Layout types +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// ============================================================================ +// FP32 -> FP8 E4M3 Quantization with blockwise scaling +// ============================================================================ + +constexpr float FP8_E4M3_MAX = 448.0f; + +__device__ __forceinline__ +uint8_t float_to_fp8_e4m3_scaled(float val, float inv_scale) { + // Apply inverse scale + val = val * inv_scale; + + // Clamp to FP8 E4M3 range + val = fminf(fmaxf(val, -FP8_E4M3_MAX), FP8_E4M3_MAX); + if (fabsf(val) < 1e-7f) return 0; + + uint32_t bits = __float_as_uint(val); + uint8_t sign = (bits >> 24) & 0x80; + int exp = ((bits >> 23) & 0xFF) - 127 + 7; // FP8 E4M3 bias = 7 + uint32_t mant = bits & 0x7FFFFF; + + if (exp <= 0) return sign; + if (exp >= 15) return sign | 0x7E; // Max FP8 E4M3 + + return sign | (static_cast(exp) << 3) | static_cast(mant >> 20); +} + +// Simple FP32 -> FP8 conversion kernel (unity scale for testing) +__global__ void quantize_fp32_to_fp8_kernel( + const float* __restrict__ input, + cutlass::float_e4m3_t* __restrict__ output, + int64_t num_elements +) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= num_elements) return; + + // Simple quantization with unity scale (inv_scale = 1.0) + uint8_t fp8 = float_to_fp8_e4m3_scaled(input[idx], 1.0f); + output[idx] = cutlass::float_e4m3_t::bitcast(fp8); +} + +// Transpose and quantize B from RowMajor [K,N] to ColumnMajor [K,N] +// Input: B_row[k,n] = B[k * N + n] (RowMajor) +// Output: B_col[k,n] = B[k + n * K] (ColumnMajor) +__global__ void transpose_quantize_fp32_to_fp8_kernel( + const float* __restrict__ 
input, // [K, N] RowMajor + cutlass::float_e4m3_t* __restrict__ output, // [K, N] ColumnMajor + int K, int N +) { + int k = blockIdx.y * blockDim.y + threadIdx.y; + int n = blockIdx.x * blockDim.x + threadIdx.x; + + if (k >= K || n >= N) return; + + // Read from RowMajor: B[k,n] = input[k * N + n] + float val = input[k * N + n]; + + // Write to ColumnMajor: B[k,n] = output[k + n * K] + uint8_t fp8 = float_to_fp8_e4m3_scaled(val, 1.0f); + output[k + n * K] = cutlass::float_e4m3_t::bitcast(fp8); +} + +// Fill scale factors with unity (1.0f) +// Example 87a uses float scale factors, not E8M0 +__global__ void fill_scale_factors_unity_kernel( + float* __restrict__ scales, + size_t num_scales +) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= num_scales) return; + + scales[idx] = 1.0f; +} + +// ============================================================================ +// BF16 -> FP32 Conversion +// ============================================================================ + +__global__ void bf16_to_fp32_kernel( + const cutlass::bfloat16_t* __restrict__ input, + float* __restrict__ output, + int64_t num_elements +) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= num_elements) return; + + output[idx] = static_cast(input[idx]); +} + +// ============================================================================ +// FP8 GEMM Entry Point +// ============================================================================ + +cudaError_t gemm_fp8( + const float* A, // [M, K] FP32 input + const float* B, // [K, N] FP32 input (will be transposed internally) + float* D, // [M, N] FP32 output + int M, int N, int K, + float alpha, + float beta, + cudaStream_t stream +) { + fprintf(stderr, "[FP8 GEMM SM120] BUILD_VER=2024-12-24-A\n"); + fprintf(stderr, "[FP8 GEMM SM120] Starting M=%d, N=%d, K=%d\n", M, N, K); + + // Check input/output alignment + fprintf(stderr, "[FP8 GEMM SM120] Alignment check:\n"); + fprintf(stderr, " A ptr alignment mod 128 = %llu\n", (unsigned long long)((uintptr_t)A % 128)); + fprintf(stderr, " B ptr alignment mod 128 = %llu\n", (unsigned long long)((uintptr_t)B % 128)); + fprintf(stderr, " D ptr alignment mod 128 = %llu\n", (unsigned long long)((uintptr_t)D % 128)); + + // Sizes + int64_t size_A = static_cast(M) * K; + int64_t size_B = static_cast(K) * N; + int64_t size_D = static_cast(M) * N; + + // Allocate FP8 data buffers + cutlass::device_memory::allocation buf_A_fp8(size_A); + cutlass::device_memory::allocation buf_B_fp8(size_B); + cutlass::device_memory::allocation buf_C_bf16(size_D); // For epilogue C input + cutlass::device_memory::allocation buf_D_bf16(size_D); + + auto* d_A_fp8 = buf_A_fp8.get(); + auto* d_B_fp8 = buf_B_fp8.get(); + auto* d_C_bf16 = buf_C_bf16.get(); + auto* d_D_bf16 = buf_D_bf16.get(); + + fprintf(stderr, "[FP8 GEMM SM120] FP8 buffers allocated: A=%p, B=%p, D_bf16=%p\n", + (void*)d_A_fp8, (void*)d_B_fp8, (void*)d_D_bf16); + fprintf(stderr, "[FP8 GEMM SM120] Internal alignment check:\n"); + fprintf(stderr, " A_fp8 mod 128 = %llu\n", (unsigned long long)((uintptr_t)d_A_fp8 % 128)); + fprintf(stderr, " B_fp8 mod 128 = %llu\n", (unsigned long long)((uintptr_t)d_B_fp8 % 128)); + fprintf(stderr, " D_bf16 mod 128 = %llu\n", (unsigned long long)((uintptr_t)d_D_bf16 % 128)); + + // Calculate scale factor sizes using ScaleConfig (from example 87a) + auto problem_shape = cute::make_shape(M, N, K, 1); + LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(problem_shape); + LayoutSFB 
layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(problem_shape); + + size_t sfa_size = size(filter_zeros(layout_SFA)); + size_t sfb_size = size(filter_zeros(layout_SFB)); + + fprintf(stderr, "[FP8 GEMM SM120] Scale factor sizes: SFA=%zu, SFB=%zu\n", sfa_size, sfb_size); + fprintf(stderr, "[FP8 GEMM SM120] Scale factor layouts:\n"); + cute::print(" layout_SFA: "); cute::print(layout_SFA); cute::print("\n"); + cute::print(" layout_SFB: "); cute::print(layout_SFB); cute::print("\n"); + + // Allocate scale factor buffers (float, not E8M0) + // TMA requires 128-byte alignment for each scale factor access + // Pad to at least 32 floats (128 bytes) to ensure TMA alignment + size_t sfa_padded = std::max(sfa_size, size_t(32)); + size_t sfb_padded = std::max(sfb_size, size_t(32)); + fprintf(stderr, "[FP8 GEMM SM120] Scale factor padded sizes: SFA=%zu->%zu, SFB=%zu->%zu\n", + sfa_size, sfa_padded, sfb_size, sfb_padded); + + cutlass::device_memory::allocation buf_SFA(sfa_padded); + cutlass::device_memory::allocation buf_SFB(sfb_padded); + + auto* d_SFA = buf_SFA.get(); + auto* d_SFB = buf_SFB.get(); + + fprintf(stderr, "[FP8 GEMM SM120] Scale factor alignment:\n"); + fprintf(stderr, " SFA mod 128 = %llu\n", (unsigned long long)((uintptr_t)d_SFA % 128)); + fprintf(stderr, " SFB mod 128 = %llu\n", (unsigned long long)((uintptr_t)d_SFB % 128)); + + // Quantize A and B + int threads = 256; + int blocks_A_data = (size_A + threads - 1) / threads; + + // Convert A: FP32 -> FP8 (keep RowMajor) + quantize_fp32_to_fp8_kernel<<>>( + A, d_A_fp8, size_A + ); + + // Convert B: FP32 RowMajor -> FP8 ColumnMajor (transpose during quantization) + // B input is [K, N] RowMajor, output needs to be [K, N] ColumnMajor + dim3 block_B(16, 16); + dim3 grid_B((N + 15) / 16, (K + 15) / 16); + transpose_quantize_fp32_to_fp8_kernel<<>>( + B, d_B_fp8, K, N + ); + fprintf(stderr, "[FP8 GEMM SM120] B transposed from RowMajor to ColumnMajor\n"); + + // Fill scale factors with 1.0 (fill entire padded buffer) + int blocks_SFA_fill = (sfa_padded + threads - 1) / threads; + int blocks_SFB_fill = (sfb_padded + threads - 1) / threads; + fill_scale_factors_unity_kernel<<>>(d_SFA, sfa_padded); + fill_scale_factors_unity_kernel<<>>(d_SFB, sfb_padded); + + // Sync and check for errors + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "[FP8 GEMM SM120] Quantization sync failed: %s\n", cudaGetErrorString(err)); + return err; + } + fprintf(stderr, "[FP8 GEMM SM120] Quantization OK\n"); + + // Build strides (from example 87a) + // For CUTLASS 3.x with cute layouts: + // - StrideA for RowMajor A[M,K]: packed stride from shape (M, K, L) + // - StrideB for ColumnMajor B[K,N]: packed stride from shape (N, K, L) + // Note: The shape passed to make_cute_packed_stride is the logical GEMM shape, + // not the memory layout shape. CUTLASS handles the layout internally. 
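  // For reference (assuming the usual CUTLASS 3.x packed-stride convention,
  // not verified against this exact CUTLASS revision): with RowMajor A [M,K],
  // ColumnMajor B [K,N], and RowMajor C/D [M,N], the packed strides built
  // below are expected to come out roughly as
  //   stride_a ~ (K, 1, M*K), stride_b ~ (K, 1, N*K),
  //   stride_c ~ (N, 1, M*N), stride_d ~ (N, 1, M*N),
  // which is what the stride debug printout below should confirm.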
+ StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1)); + StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, 1)); + StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, 1)); + StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, 1)); + + // Debug: Print stride values + fprintf(stderr, "[FP8 GEMM SM120] Stride debug:\n"); + fprintf(stderr, " stride_a: (%lld, %lld, %lld)\n", + (long long)cute::get<0>(stride_a), (long long)cute::get<1>(stride_a), (long long)cute::get<2>(stride_a)); + fprintf(stderr, " stride_b: (%lld, %lld, %lld)\n", + (long long)cute::get<0>(stride_b), (long long)cute::get<1>(stride_b), (long long)cute::get<2>(stride_b)); + fprintf(stderr, " stride_c: (%lld, %lld, %lld)\n", + (long long)cute::get<0>(stride_c), (long long)cute::get<1>(stride_c), (long long)cute::get<2>(stride_c)); + fprintf(stderr, " stride_d: (%lld, %lld, %lld)\n", + (long long)cute::get<0>(stride_d), (long long)cute::get<1>(stride_d), (long long)cute::get<2>(stride_d)); + + // Build CUTLASS arguments (following example 87a structure) + // Note: Even with beta=0, we must pass a valid C pointer (CUTLASS may dereference it) + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + { // Mainloop arguments + d_A_fp8, stride_a, + d_B_fp8, stride_b, + d_SFA, layout_SFA, + d_SFB, layout_SFB + }, + { // Epilogue arguments + {}, // epilogue.thread (will be filled below) + d_C_bf16, stride_c, // C pointer (valid even with beta=0) + d_D_bf16, stride_d // D pointer + } + }; + + // Set alpha/beta + arguments.epilogue.thread.alpha = alpha; + arguments.epilogue.thread.beta = beta; + + fprintf(stderr, "[FP8 GEMM SM120] Arguments built, alpha=%f, beta=%f\n", alpha, beta); + + // Instantiate and run GEMM + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8 GEMM SM120] can_implement failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + fprintf(stderr, "[FP8 GEMM SM120] can_implement OK\n"); + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + fprintf(stderr, "[FP8 GEMM SM120] Workspace size: %zu bytes\n", workspace_size); + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8 GEMM SM120] initialize failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + fprintf(stderr, "[FP8 GEMM SM120] initialize OK\n"); + + status = gemm_op.run(); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8 GEMM SM120] run failed: %d\n", static_cast(status)); + return cudaErrorLaunchFailure; + } + + // Sync and check for kernel errors + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "[FP8 GEMM SM120] GEMM sync failed: %s\n", cudaGetErrorString(err)); + return err; + } + err = cudaGetLastError(); + if (err != cudaSuccess) { + fprintf(stderr, "[FP8 GEMM SM120] GEMM kernel error: %s\n", cudaGetErrorString(err)); + return err; + } + fprintf(stderr, "[FP8 GEMM SM120] GEMM completed OK\n"); + + // Convert BF16 output to FP32 + int blocks_D = (size_D + threads - 1) / threads; + bf16_to_fp32_kernel<<>>(d_D_bf16, D, size_D); + + // Sync before RAII cleanup + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "[FP8 GEMM SM120] 
BF16->FP32 sync failed: %s\n", cudaGetErrorString(err)); + return err; + } + fprintf(stderr, "[FP8 GEMM SM120] Complete\n"); + + return cudaSuccess; +} + +bool is_available() { + int device_id = 0; + cudaGetDevice(&device_id); + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device_id); + return (props.major * 10 + props.minor) >= 120; +} + +} // namespace fp8_gemm_sm120 +} // namespace ops +} // namespace pygpukit + +// Extern C for linking +extern "C" { + cudaError_t pygpukit_gemm_fp8_sm120( + const float* A, const float* B, float* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return pygpukit::ops::fp8_gemm_sm120::gemm_fp8(A, B, D, M, N, K, alpha, beta, stream); + } + + bool pygpukit_fp8_sm120_available() { + return pygpukit::ops::fp8_gemm_sm120::is_available(); + } +} + +#else // !SM120 + +namespace pygpukit { +namespace ops { +namespace fp8_gemm_sm120 { + +cudaError_t gemm_fp8( + const float* A, const float* B, float* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream +) { + return cudaErrorNotSupported; +} + +bool is_available() { + return false; +} + +} // namespace fp8_gemm_sm120 +} // namespace ops +} // namespace pygpukit + +extern "C" { + cudaError_t pygpukit_gemm_fp8_sm120( + const float* A, const float* B, float* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return cudaErrorNotSupported; + } + + bool pygpukit_fp8_sm120_available() { + return false; + } +} + +#endif diff --git a/native/ops/matmul/matmul_fp8_fp8_sm120.cu b/native/ops/matmul/matmul_fp8_fp8_sm120.cu new file mode 100644 index 0000000..2fd98a6 --- /dev/null +++ b/native/ops/matmul/matmul_fp8_fp8_sm120.cu @@ -0,0 +1,478 @@ +/** + * Pure FP8 GEMM implementation for SM120 (Blackwell GeForce) + * + * Path: + * 1. FP8 E4M3 input (A, B already quantized) + * 2. FP8 CUTLASS GEMM with blockwise scaling + * 3. FP8 E4M3 output (direct, no conversion) + * + * This is the "true" FP8 GEMM for FP8 models (Llama 3.1 FP8, etc.) + * where weights and activations are already in FP8 format. + * + * Implementation based on CUTLASS example 87a: + * "87a_blackwell_geforce_fp8_bf16_gemm_blockwise" + * Modified for FP8 output instead of BF16. 
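 *
 * Minimal host-side usage sketch for the exported C entry point defined at the
 * bottom of this file (illustrative only; allocating the device buffers and
 * packing the caller's data into FP8 E4M3 bytes is assumed to happen elsewhere):
 *
 *   // d_A: [M, K] FP8 bytes, RowMajor
 *   // d_B: [K, N] FP8 bytes, ColumnMajor (pre-transposed)
 *   // d_D: [M, N] FP8 bytes, RowMajor output
 *   cudaError_t err = pygpukit_gemm_fp8_fp8_sm120(
 *       d_A, d_B, d_D, M, N, K, /*alpha=*/1.0f, /*beta=*/0.0f, /*stream=*/nullptr);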
+ */ + +#include +#include +#include +#include + +// Enable FP8 SM120 +#define PYGPUKIT_ENABLE_FP8_SM120 + +// Only compile for SM120+ AND when explicitly enabled +#if (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) && defined(PYGPUKIT_ENABLE_FP8_SM120) + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/blockwise_scale_layout.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/device_memory.h" + +// Alignment patch for Issue #2902 workaround +#define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1 +#include "aligned_copy_sm120.cuh" + +using namespace cute; + +namespace pygpukit { +namespace ops { +namespace fp8_fp8_gemm_sm120 { + +// ============================================================================ +// GEMM Configuration: FP8 E4M3 x FP8 E4M3 -> FP8 E4M3 with blockwise scaling +// ============================================================================ + +// A matrix: FP8 E4M3, RowMajor +using ElementA = cutlass::float_e4m3_t; +using LayoutATag = cutlass::layout::RowMajor; +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + +// B matrix: FP8 E4M3, ColumnMajor +using ElementB = cutlass::float_e4m3_t; +using LayoutBTag = cutlass::layout::ColumnMajor; +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + +// Output: FP8 E4M3 (Pure FP8 output!) +using ElementC = cutlass::float_e4m3_t; +using ElementD = cutlass::float_e4m3_t; +using LayoutCTag = cutlass::layout::RowMajor; +using LayoutDTag = cutlass::layout::RowMajor; +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentD = AlignmentC; + +// Accumulator type (still float for precision) +using ElementAccumulator = float; +using ElementCompute = float; + +// SM120 GeForce architecture with TensorOp +using ArchTag = cutlass::arch::Sm120; +using OperatorClass = cutlass::arch::OpClassTensorOp; + +// MMA and Cluster Tile Shapes +using MmaTileShape_MNK = Shape<_128, _128, _128>; +using ClusterShape_MNK = Shape<_1, _1, _1>; // GeForce: no cluster support + +// Scale configuration (trivial blockwise scaling from example 87a) +using ScaleConfig = decltype(cutlass::detail::sm120_trivial_blockwise_scale_config(MmaTileShape_MNK{})); +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + +// Epilogue - outputs FP8 +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto +>::CollectiveOp; + +// Mainloop with scale factor layouts +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto +>::CollectiveOp; + +// GEMM Kernel +using GemmKernel = 
cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + void // Default CLC scheduler +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Stride and Layout types +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// ============================================================================ +// Scale factor initialization (unity for now, can be extended for per-tensor/block) +// ============================================================================ + +__global__ void fill_scale_factors_unity_kernel( + float* __restrict__ scales, + size_t num_scales +) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= num_scales) return; + scales[idx] = 1.0f; +} + +// ============================================================================ +// FP8 -> FP8 GEMM Entry Point +// ============================================================================ + +cudaError_t gemm_fp8_fp8( + const cutlass::float_e4m3_t* A, // [M, K] FP8 input (RowMajor) + const cutlass::float_e4m3_t* B, // [K, N] FP8 input (ColumnMajor, pre-transposed) + cutlass::float_e4m3_t* D, // [M, N] FP8 output + int M, int N, int K, + float alpha, + float beta, + cudaStream_t stream +) { + // Sizes + int64_t size_D = static_cast(M) * N; + + // Allocate C buffer for epilogue (even with beta=0, CUTLASS needs valid pointer) + cutlass::device_memory::allocation buf_C(size_D); + auto* d_C = buf_C.get(); + + // Calculate scale factor sizes using ScaleConfig + auto problem_shape = cute::make_shape(M, N, K, 1); + LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(problem_shape); + LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(problem_shape); + + size_t sfa_size = size(filter_zeros(layout_SFA)); + size_t sfb_size = size(filter_zeros(layout_SFB)); + + // Pad to 32 floats (128 bytes) for TMA alignment + size_t sfa_padded = std::max(sfa_size, size_t(32)); + size_t sfb_padded = std::max(sfb_size, size_t(32)); + + cutlass::device_memory::allocation buf_SFA(sfa_padded); + cutlass::device_memory::allocation buf_SFB(sfb_padded); + + auto* d_SFA = buf_SFA.get(); + auto* d_SFB = buf_SFB.get(); + + // Fill scale factors with 1.0 + int threads = 256; + int blocks_SFA_fill = (sfa_padded + threads - 1) / threads; + int blocks_SFB_fill = (sfb_padded + threads - 1) / threads; + fill_scale_factors_unity_kernel<<>>(d_SFA, sfa_padded); + fill_scale_factors_unity_kernel<<>>(d_SFB, sfb_padded); + + // Build strides + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1)); + StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, 1)); + StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, 1)); + StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, 1)); + + // Build CUTLASS arguments + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + { // Mainloop arguments + A, stride_a, + B, stride_b, + d_SFA, layout_SFA, + d_SFB, layout_SFB + }, + { // Epilogue arguments + {}, // epilogue.thread + d_C, stride_c, + D, stride_d + } + }; + + // Set alpha/beta + arguments.epilogue.thread.alpha = alpha; + arguments.epilogue.thread.beta = beta; + + // Instantiate and run GEMM + Gemm gemm_op; + + cutlass::Status status = 
gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8_FP8 GEMM SM120] can_implement failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8_FP8 GEMM SM120] initialize failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + + status = gemm_op.run(stream); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8_FP8 GEMM SM120] run failed: %d\n", static_cast(status)); + return cudaErrorLaunchFailure; + } + + return cudaSuccess; +} + +// Wrapper for raw uint8_t pointers (for Python binding convenience) +cudaError_t gemm_fp8_fp8_raw( + const uint8_t* A, // [M, K] FP8 as raw bytes + const uint8_t* B, // [K, N] FP8 as raw bytes (ColumnMajor) + uint8_t* D, // [M, N] FP8 as raw bytes + int M, int N, int K, + float alpha, + float beta, + cudaStream_t stream +) { + return gemm_fp8_fp8( + reinterpret_cast(A), + reinterpret_cast(B), + reinterpret_cast(D), + M, N, K, alpha, beta, stream + ); +} + +// ============================================================================ +// Get scale factor sizes for a given problem size +// ============================================================================ + +void get_scale_sizes(int M, int N, int K, size_t* sfa_size, size_t* sfb_size) { + auto problem_shape = cute::make_shape(M, N, K, 1); + LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(problem_shape); + LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(problem_shape); + + *sfa_size = size(filter_zeros(layout_SFA)); + *sfb_size = size(filter_zeros(layout_SFB)); +} + +// ============================================================================ +// FP8 -> FP8 GEMM with Blockwise Scaling +// ============================================================================ + +cudaError_t gemm_fp8_fp8_blockwise( + const cutlass::float_e4m3_t* A, // [M, K] FP8 input (RowMajor) + const cutlass::float_e4m3_t* B, // [K, N] FP8 input (ColumnMajor, pre-transposed) + cutlass::float_e4m3_t* D, // [M, N] FP8 output + const float* scale_A, // Scale factors for A + const float* scale_B, // Scale factors for B + int M, int N, int K, + float alpha, + float beta, + cudaStream_t stream +) { + // Sizes + int64_t size_D = static_cast(M) * N; + + // Allocate C buffer for epilogue + cutlass::device_memory::allocation buf_C(size_D); + auto* d_C = buf_C.get(); + + // Calculate scale factor layouts + auto problem_shape = cute::make_shape(M, N, K, 1); + LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(problem_shape); + LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(problem_shape); + + // Build strides + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1)); + StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, 1)); + StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, 1)); + StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, 1)); + + // Build CUTLASS arguments with user-provided scale factors + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + { // Mainloop arguments + A, stride_a, + B, stride_b, + scale_A, layout_SFA, + scale_B, layout_SFB + }, + { // 
Epilogue arguments + {}, // epilogue.thread + d_C, stride_c, + D, stride_d + } + }; + + // Set alpha/beta + arguments.epilogue.thread.alpha = alpha; + arguments.epilogue.thread.beta = beta; + + // Instantiate and run GEMM + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8_FP8 Blockwise GEMM SM120] can_implement failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8_FP8 Blockwise GEMM SM120] initialize failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + + status = gemm_op.run(stream); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8_FP8 Blockwise GEMM SM120] run failed: %d\n", static_cast(status)); + return cudaErrorLaunchFailure; + } + + return cudaSuccess; +} + +// Wrapper for raw uint8_t pointers +cudaError_t gemm_fp8_fp8_blockwise_raw( + const uint8_t* A, + const uint8_t* B, + uint8_t* D, + const float* scale_A, + const float* scale_B, + int M, int N, int K, + float alpha, + float beta, + cudaStream_t stream +) { + return gemm_fp8_fp8_blockwise( + reinterpret_cast(A), + reinterpret_cast(B), + reinterpret_cast(D), + scale_A, scale_B, + M, N, K, alpha, beta, stream + ); +} + +bool is_available() { + int device_id = 0; + cudaGetDevice(&device_id); + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device_id); + return (props.major * 10 + props.minor) >= 120; +} + +} // namespace fp8_fp8_gemm_sm120 +} // namespace ops +} // namespace pygpukit + +// Extern C for linking +extern "C" { + cudaError_t pygpukit_gemm_fp8_fp8_sm120( + const uint8_t* A, const uint8_t* B, uint8_t* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return pygpukit::ops::fp8_fp8_gemm_sm120::gemm_fp8_fp8_raw( + A, B, D, M, N, K, alpha, beta, stream + ); + } + + bool pygpukit_fp8_fp8_sm120_available() { + return pygpukit::ops::fp8_fp8_gemm_sm120::is_available(); + } + + // Blockwise scaled version + cudaError_t pygpukit_gemm_fp8_fp8_blockwise_sm120( + const uint8_t* A, const uint8_t* B, uint8_t* D, + const float* scale_A, const float* scale_B, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return pygpukit::ops::fp8_fp8_gemm_sm120::gemm_fp8_fp8_blockwise_raw( + A, B, D, scale_A, scale_B, M, N, K, alpha, beta, stream + ); + } + + // Get scale factor sizes for a given problem + void pygpukit_fp8_fp8_get_scale_sizes( + int M, int N, int K, + size_t* sfa_size, size_t* sfb_size + ) { + pygpukit::ops::fp8_fp8_gemm_sm120::get_scale_sizes(M, N, K, sfa_size, sfb_size); + } +} + +#else // !SM120 + +namespace pygpukit { +namespace ops { +namespace fp8_fp8_gemm_sm120 { + +cudaError_t gemm_fp8_fp8_raw( + const uint8_t* A, const uint8_t* B, uint8_t* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream +) { + return cudaErrorNotSupported; +} + +bool is_available() { + return false; +} + +} // namespace fp8_fp8_gemm_sm120 +} // namespace ops +} // namespace pygpukit + +extern "C" { + cudaError_t pygpukit_gemm_fp8_fp8_sm120( + const uint8_t* A, const uint8_t* B, uint8_t* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return cudaErrorNotSupported; + } + + bool pygpukit_fp8_fp8_sm120_available() { + 
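        // Descriptive note on this #else branch: it covers builds where
        // CUTLASS_ARCH_MMA_SM120_SUPPORTED is not defined. Every GEMM entry point
        // returns cudaErrorNotSupported, the scale-size query reports zero, and this
        // availability probe returns false, so a caller can detect the missing FP8 I/O
        // kernels at runtime rather than failing at link time.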
return false; + } + + cudaError_t pygpukit_gemm_fp8_fp8_blockwise_sm120( + const uint8_t* A, const uint8_t* B, uint8_t* D, + const float* scale_A, const float* scale_B, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return cudaErrorNotSupported; + } + + void pygpukit_fp8_fp8_get_scale_sizes( + int M, int N, int K, + size_t* sfa_size, size_t* sfb_size + ) { + *sfa_size = 0; + *sfb_size = 0; + } +} + +#endif diff --git a/native/ops/matmul/matmul_fp8_sm100.cu b/native/ops/matmul/matmul_fp8_sm100.cu new file mode 100644 index 0000000..5b34707 --- /dev/null +++ b/native/ops/matmul/matmul_fp8_sm100.cu @@ -0,0 +1,372 @@ +/** + * FP8 GEMM implementation for SM100 (Blackwell datacenter) + * + * Path: + * 1. FP32 input + * 2. FP8 quantization with blockwise scaling + * 3. FP8 CUTLASS GEMM (SM100 tcgen05) + * 4. FP32 output + * + * Based on CUTLASS example 81: blackwell_gemm_blockwise + * + * This serves as potential fallback for SM120 (Blackwell GeForce). + * SM100 and SM120 are both Blackwell architecture - the kernel might work. + */ + +#include +#include +#include +#include +#include + +// Only compile for SM100+ +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/blockwise_scale_layout.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/device_memory.h" + +using namespace cute; + +namespace pygpukit { +namespace ops { +namespace fp8_gemm_sm100 { + +// ============================================================================ +// GEMM Configuration: FP8 E4M3 x FP8 E4M3 -> FP32 with blockwise scaling +// Based on CUTLASS example 81 +// ============================================================================ + +// A matrix: FP8 E4M3, RowMajor +using ElementA = cutlass::float_e4m3_t; +using LayoutA = cutlass::layout::RowMajor; +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // 16 + +// B matrix: FP8 E4M3, ColumnMajor +using ElementB = cutlass::float_e4m3_t; +using LayoutB = cutlass::layout::ColumnMajor; +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // 16 + +// Output: FP32 (we use bfloat16 internally then convert) +using ElementC = cutlass::bfloat16_t; +using ElementD = cutlass::bfloat16_t; +using LayoutC = cutlass::layout::RowMajor; +using LayoutD = cutlass::layout::RowMajor; +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentD = AlignmentC; + +// Accumulator type +using ElementAccumulator = float; +using ElementCompute = float; + +// SM100 Blackwell architecture +using ArchTag = cutlass::arch::Sm100; +using OperatorClass = cutlass::arch::OpClassTensorOp; + +// Tile and cluster shapes - using smaller tiles for better compatibility +using MmaTileShape_MNK = Shape<_128, _128, _128>; +using ClusterShape_MNK = Shape<_1, _1, _1>; + +// Scale config for blockwise scaling +using ScaleConfig = decltype(cutlass::detail::sm100_trivial_blockwise_scale_config(MmaTileShape_MNK{})); +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + +// Epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + 
MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto +>::CollectiveOp; + +// Mainloop with blockwise scaling +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + cutlass::gemm::KernelScheduleSm100Blockwise +>::CollectiveOp; + +// GEMM Kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + void +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// ============================================================================ +// FP32 -> FP8 Quantization +// ============================================================================ + +constexpr float FP8_E4M3_MAX = 448.0f; + +__device__ __forceinline__ +uint8_t float_to_fp8_e4m3_scaled(float val, float inv_scale) { + val = val * inv_scale; + val = fminf(fmaxf(val, -FP8_E4M3_MAX), FP8_E4M3_MAX); + + if (fabsf(val) < 1e-7f) return 0; + + uint32_t bits = __float_as_uint(val); + uint8_t sign = (bits >> 24) & 0x80; + int exp = ((bits >> 23) & 0xFF) - 127 + 7; + uint32_t mant = bits & 0x7FFFFF; + + if (exp <= 0) return sign; + if (exp >= 15) return sign | 0x7E; + + return sign | (static_cast(exp) << 3) | static_cast(mant >> 20); +} + +__global__ void quantize_fp32_to_fp8_kernel( + const float* __restrict__ input, + cutlass::float_e4m3_t* __restrict__ output, + int64_t num_elements +) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= num_elements) return; + + uint8_t fp8 = float_to_fp8_e4m3_scaled(input[idx], 1.0f); + output[idx] = cutlass::float_e4m3_t::bitcast(fp8); +} + +__global__ void transpose_quantize_fp32_to_fp8_kernel( + const float* __restrict__ input, + cutlass::float_e4m3_t* __restrict__ output, + int K, int N +) { + int k = blockIdx.y * blockDim.y + threadIdx.y; + int n = blockIdx.x * blockDim.x + threadIdx.x; + + if (k >= K || n >= N) return; + + float val = input[k * N + n]; + uint8_t fp8 = float_to_fp8_e4m3_scaled(val, 1.0f); + output[k + n * K] = cutlass::float_e4m3_t::bitcast(fp8); +} + +__global__ void fill_scale_factors_unity_kernel( + float* __restrict__ scales, + size_t num_scales +) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= num_scales) return; + scales[idx] = 1.0f; +} + +__global__ void bf16_to_fp32_kernel( + const cutlass::bfloat16_t* __restrict__ input, + float* __restrict__ output, + int64_t num_elements +) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= num_elements) return; + output[idx] = static_cast(input[idx]); +} + +// ============================================================================ +// FP8 GEMM Entry Point +// ============================================================================ + +cudaError_t gemm_fp8( + const float* A, + const float* B, + float* D, + int M, int N, int K, + float alpha, + float beta, + 
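    // stream: all work below (FP8 quantization, scale-factor fill, the CUTLASS GEMM,
    // and the final BF16 -> FP32 conversion) is launched on this stream and
    // synchronized internally before the function returns.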
cudaStream_t stream +) { + // Sizes + int64_t size_A = static_cast(M) * K; + int64_t size_B = static_cast(K) * N; + int64_t size_D = static_cast(M) * N; + + // Allocate FP8 buffers + cutlass::device_memory::allocation buf_A_fp8(size_A); + cutlass::device_memory::allocation buf_B_fp8(size_B); + cutlass::device_memory::allocation buf_C_bf16(size_D); + cutlass::device_memory::allocation buf_D_bf16(size_D); + + auto* d_A_fp8 = buf_A_fp8.get(); + auto* d_B_fp8 = buf_B_fp8.get(); + auto* d_C_bf16 = buf_C_bf16.get(); + auto* d_D_bf16 = buf_D_bf16.get(); + + // Scale factor sizes + auto problem_shape = cute::make_shape(M, N, K, 1); + LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(problem_shape); + LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(problem_shape); + + size_t sfa_size = size(filter_zeros(layout_SFA)); + size_t sfb_size = size(filter_zeros(layout_SFB)); + + cutlass::device_memory::allocation buf_SFA(sfa_size); + cutlass::device_memory::allocation buf_SFB(sfb_size); + + auto* d_SFA = buf_SFA.get(); + auto* d_SFB = buf_SFB.get(); + + // Quantize + int threads = 256; + int blocks_A = (size_A + threads - 1) / threads; + + quantize_fp32_to_fp8_kernel<<>>(A, d_A_fp8, size_A); + + dim3 block_B(16, 16); + dim3 grid_B((N + 15) / 16, (K + 15) / 16); + transpose_quantize_fp32_to_fp8_kernel<<>>(B, d_B_fp8, K, N); + + // Fill scale factors + int blocks_SFA = (sfa_size + threads - 1) / threads; + int blocks_SFB = (sfb_size + threads - 1) / threads; + fill_scale_factors_unity_kernel<<>>(d_SFA, sfa_size); + fill_scale_factors_unity_kernel<<>>(d_SFB, sfb_size); + + cudaError_t err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) return err; + + // Build strides + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1)); + StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, 1)); + StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, 1)); + StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, 1)); + + // Build arguments + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + {d_A_fp8, stride_a, d_B_fp8, stride_b, d_SFA, layout_SFA, d_SFB, layout_SFB}, + {{alpha, beta}, d_C_bf16, stride_c, d_D_bf16, stride_d} + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8 GEMM SM100] can_implement failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8 GEMM SM100] initialize failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + + status = gemm_op.run(stream); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8 GEMM SM100] run failed: %d\n", static_cast(status)); + return cudaErrorLaunchFailure; + } + + err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) { + fprintf(stderr, "[FP8 GEMM SM100] sync failed: %s\n", cudaGetErrorString(err)); + return err; + } + + // Convert BF16 to FP32 + int blocks_D = (size_D + threads - 1) / threads; + bf16_to_fp32_kernel<<>>(d_D_bf16, D, size_D); + + err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) return err; + + return cudaSuccess; +} + +bool 
is_available() { + int device_id = 0; + cudaGetDevice(&device_id); + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device_id); + // SM100+ (Blackwell datacenter and consumer) + return (props.major * 10 + props.minor) >= 100; +} + +} // namespace fp8_gemm_sm100 +} // namespace ops +} // namespace pygpukit + +extern "C" { + cudaError_t pygpukit_gemm_fp8_sm100( + const float* A, const float* B, float* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return pygpukit::ops::fp8_gemm_sm100::gemm_fp8(A, B, D, M, N, K, alpha, beta, stream); + } + + bool pygpukit_fp8_sm100_available() { + return pygpukit::ops::fp8_gemm_sm100::is_available(); + } +} + +#else // !SM100 + +namespace pygpukit { +namespace ops { +namespace fp8_gemm_sm100 { + +cudaError_t gemm_fp8( + const float* A, const float* B, float* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream +) { + return cudaErrorNotSupported; +} + +bool is_available() { + return false; +} + +} // namespace fp8_gemm_sm100 +} // namespace ops +} // namespace pygpukit + +extern "C" { + cudaError_t pygpukit_gemm_fp8_sm100( + const float* A, const float* B, float* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return cudaErrorNotSupported; + } + + bool pygpukit_fp8_sm100_available() { + return false; + } +} + +#endif diff --git a/native/ops/matmul/matmul_fp8_sm90.cu b/native/ops/matmul/matmul_fp8_sm90.cu new file mode 100644 index 0000000..c2eef4e --- /dev/null +++ b/native/ops/matmul/matmul_fp8_sm90.cu @@ -0,0 +1,400 @@ +/** + * FP8 GEMM implementation for SM90 (Hopper) + * + * Path: + * 1. FP32 input + * 2. FP8 quantization with per-tensor scaling + * 3. FP8 CUTLASS GEMM (Hopper TMA + WGMMA) + * 4. FP32 output + * + * Based on CUTLASS example 54: hopper_fp8_warp_specialized_gemm + * + * This serves as fallback for SM120 (Blackwell GeForce) until CUTLASS + * fixes the blockwise scaling alignment bug (#2902). 
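 *
 * Per-tensor scaling used here: s = absmax / 448 (448 is the largest finite E4M3
 * value, encoded as 0x7E), inputs are quantized with 1/s, and the epilogue rescales
 * the product via alpha' = alpha * s_A * s_B. E4M3 layout is 1 sign, 4 exponent
 * (bias 7), and 3 mantissa bits; for example 1.5f quantizes to 0x3C. Note that the
 * conversion below truncates the mantissa rather than rounding to nearest.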
+ */ + +#include +#include +#include +#include +#include + +// Only compile for SM90+ +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/device_memory.h" + +using namespace cute; + +namespace pygpukit { +namespace ops { +namespace fp8_gemm_sm90 { + +// ============================================================================ +// GEMM Configuration: FP8 E4M3 x FP8 E4M3 -> FP32 with per-tensor scaling +// Based on CUTLASS example 54 +// ============================================================================ + +// A matrix: FP8 E4M3, RowMajor +using ElementA = cutlass::float_e4m3_t; +using LayoutATag = cutlass::layout::RowMajor; +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // 16 + +// B matrix: FP8 E4M3, ColumnMajor +using ElementB = cutlass::float_e4m3_t; +using LayoutBTag = cutlass::layout::ColumnMajor; +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // 16 + +// Output: FP32 (we'll convert internally) +using ElementC = float; +using ElementD = float; +using LayoutCTag = cutlass::layout::RowMajor; +using LayoutDTag = cutlass::layout::RowMajor; +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // 4 +constexpr int AlignmentD = AlignmentC; + +// Accumulator type +using ElementAccumulator = float; +using ElementCompute = float; + +// SM90 Hopper architecture +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; + +// Tile and cluster shapes for Hopper +using TileShape = Shape<_128, _128, _64>; +using ClusterShape = Shape<_1, _1, _1>; // Simple 1x1x1 cluster for compatibility + +// Kernel schedule +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + +// Epilogue (simple linear combination: D = alpha * A @ B + beta * C) +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + EpilogueSchedule +>::CollectiveOp; + +// Mainloop +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule +>::CollectiveOp; + +// GEMM Kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// ============================================================================ +// FP32 -> FP8 Quantization with per-tensor scaling +// 
============================================================================ + +constexpr float FP8_E4M3_MAX = 448.0f; + +// Find max absolute value in tensor (for computing scale) +__global__ void find_absmax_kernel( + const float* __restrict__ input, + float* __restrict__ absmax, + int64_t num_elements +) { + __shared__ float shared_max[256]; + + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + float local_max = 0.0f; + + // Grid-stride loop + for (int64_t i = idx; i < num_elements; i += static_cast(gridDim.x) * blockDim.x) { + local_max = fmaxf(local_max, fabsf(input[i])); + } + + shared_max[threadIdx.x] = local_max; + __syncthreads(); + + // Reduction within block + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + shared_max[threadIdx.x] = fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + s]); + } + __syncthreads(); + } + + if (threadIdx.x == 0) { + atomicMax(reinterpret_cast(absmax), + __float_as_int(shared_max[0])); + } +} + +// Quantize FP32 to FP8 with scale +__device__ __forceinline__ +uint8_t float_to_fp8_e4m3_scaled(float val, float inv_scale) { + val = val * inv_scale; + val = fminf(fmaxf(val, -FP8_E4M3_MAX), FP8_E4M3_MAX); + + if (fabsf(val) < 1e-7f) return 0; + + uint32_t bits = __float_as_uint(val); + uint8_t sign = (bits >> 24) & 0x80; + int exp = ((bits >> 23) & 0xFF) - 127 + 7; // FP8 E4M3 bias = 7 + uint32_t mant = bits & 0x7FFFFF; + + if (exp <= 0) return sign; + if (exp >= 15) return sign | 0x7E; + + return sign | (static_cast(exp) << 3) | static_cast(mant >> 20); +} + +__global__ void quantize_fp32_to_fp8_scaled_kernel( + const float* __restrict__ input, + cutlass::float_e4m3_t* __restrict__ output, + float inv_scale, + int64_t num_elements +) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= num_elements) return; + + uint8_t fp8 = float_to_fp8_e4m3_scaled(input[idx], inv_scale); + output[idx] = cutlass::float_e4m3_t::bitcast(fp8); +} + +// Transpose and quantize B from RowMajor [K,N] to ColumnMajor [K,N] +__global__ void transpose_quantize_fp32_to_fp8_kernel( + const float* __restrict__ input, // [K, N] RowMajor + cutlass::float_e4m3_t* __restrict__ output, // [K, N] ColumnMajor + float inv_scale, + int K, int N +) { + int k = blockIdx.y * blockDim.y + threadIdx.y; + int n = blockIdx.x * blockDim.x + threadIdx.x; + + if (k >= K || n >= N) return; + + float val = input[k * N + n]; + uint8_t fp8 = float_to_fp8_e4m3_scaled(val, inv_scale); + output[k + n * K] = cutlass::float_e4m3_t::bitcast(fp8); +} + +// ============================================================================ +// FP8 GEMM Entry Point +// ============================================================================ + +cudaError_t gemm_fp8( + const float* A, // [M, K] FP32 input + const float* B, // [K, N] FP32 input + float* D, // [M, N] FP32 output + int M, int N, int K, + float alpha, + float beta, + cudaStream_t stream +) { + // Sizes + int64_t size_A = static_cast(M) * K; + int64_t size_B = static_cast(K) * N; + int64_t size_D = static_cast(M) * N; + + // Allocate FP8 buffers + cutlass::device_memory::allocation buf_A_fp8(size_A); + cutlass::device_memory::allocation buf_B_fp8(size_B); + cutlass::device_memory::allocation buf_C(size_D); // For beta * C + + auto* d_A_fp8 = buf_A_fp8.get(); + auto* d_B_fp8 = buf_B_fp8.get(); + auto* d_C = buf_C.get(); + + // Compute scale factors (find absmax for each tensor) + cutlass::device_memory::allocation buf_absmax_A(1); + cutlass::device_memory::allocation buf_absmax_B(1); 
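    // How the absmax reduction works: each block folds |x| into shared memory, then
    // the per-block result is merged with atomicMax on the float's raw bit pattern.
    // This is valid here because the reduced values are non-negative, and for
    // non-negative IEEE-754 floats the integer ordering of the bit patterns matches
    // the floating-point ordering, so the integer max is also the float max.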
+ + cudaMemsetAsync(buf_absmax_A.get(), 0, sizeof(float), stream); + cudaMemsetAsync(buf_absmax_B.get(), 0, sizeof(float), stream); + + int threads = 256; + int blocks_A = std::min(1024, static_cast((size_A + threads - 1) / threads)); + int blocks_B = std::min(1024, static_cast((size_B + threads - 1) / threads)); + + find_absmax_kernel<<>>(A, buf_absmax_A.get(), size_A); + find_absmax_kernel<<>>(B, buf_absmax_B.get(), size_B); + + // Copy absmax to host to compute scales + float absmax_A = 0.0f, absmax_B = 0.0f; + cudaMemcpyAsync(&absmax_A, buf_absmax_A.get(), sizeof(float), cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(&absmax_B, buf_absmax_B.get(), sizeof(float), cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + + // Compute scales: scale = absmax / FP8_MAX, inv_scale = FP8_MAX / absmax + float scale_A = (absmax_A > 0.0f) ? (absmax_A / FP8_E4M3_MAX) : 1.0f; + float scale_B = (absmax_B > 0.0f) ? (absmax_B / FP8_E4M3_MAX) : 1.0f; + float inv_scale_A = (absmax_A > 0.0f) ? (FP8_E4M3_MAX / absmax_A) : 1.0f; + float inv_scale_B = (absmax_B > 0.0f) ? (FP8_E4M3_MAX / absmax_B) : 1.0f; + + // Quantize A (keep RowMajor) + int blocks_A_q = (size_A + threads - 1) / threads; + quantize_fp32_to_fp8_scaled_kernel<<>>( + A, d_A_fp8, inv_scale_A, size_A + ); + + // Quantize and transpose B (RowMajor -> ColumnMajor) + dim3 block_B(16, 16); + dim3 grid_B((N + 15) / 16, (K + 15) / 16); + transpose_quantize_fp32_to_fp8_kernel<<>>( + B, d_B_fp8, inv_scale_B, K, N + ); + + // Initialize C buffer (for beta=0, we can skip) + if (beta != 0.0f) { + cudaMemsetAsync(d_C, 0, size_D * sizeof(float), stream); + } + + cudaError_t err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) return err; + + // Build strides + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1)); + StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, 1)); + StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, 1)); + StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, 1)); + + // Adjusted alpha to account for FP8 scaling + // Result = scale_A * scale_B * (A_fp8 @ B_fp8) + // So we multiply alpha by scale_A * scale_B + float adjusted_alpha = alpha * scale_A * scale_B; + + // Build CUTLASS arguments + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + {d_A_fp8, stride_a, d_B_fp8, stride_b}, + {{adjusted_alpha, beta}, d_C, stride_c, D, stride_d} + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8 GEMM SM90] can_implement failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8 GEMM SM90] initialize failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + + status = gemm_op.run(stream); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[FP8 GEMM SM90] run failed: %d\n", static_cast(status)); + return cudaErrorLaunchFailure; + } + + err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) { + fprintf(stderr, "[FP8 GEMM SM90] sync failed: %s\n", cudaGetErrorString(err)); + return err; + } + + return cudaSuccess; +} + +bool 
is_available() { + int device_id = 0; + cudaGetDevice(&device_id); + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device_id); + // SM90 only (Hopper) - TMA-based kernels may not work on Blackwell (SM100/SM120) + // Blackwell has different TMA behavior that causes CUTLASS initialization failures + int sm = props.major * 10 + props.minor; + return (sm >= 90 && sm < 100); +} + +} // namespace fp8_gemm_sm90 +} // namespace ops +} // namespace pygpukit + +// Extern C for linking +extern "C" { + cudaError_t pygpukit_gemm_fp8_sm90( + const float* A, const float* B, float* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return pygpukit::ops::fp8_gemm_sm90::gemm_fp8(A, B, D, M, N, K, alpha, beta, stream); + } + + bool pygpukit_fp8_sm90_available() { + return pygpukit::ops::fp8_gemm_sm90::is_available(); + } +} + +#else // !SM90 + +namespace pygpukit { +namespace ops { +namespace fp8_gemm_sm90 { + +cudaError_t gemm_fp8( + const float* A, const float* B, float* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream +) { + return cudaErrorNotSupported; +} + +bool is_available() { + return false; +} + +} // namespace fp8_gemm_sm90 +} // namespace ops +} // namespace pygpukit + +extern "C" { + cudaError_t pygpukit_gemm_fp8_sm90( + const float* A, const float* B, float* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return cudaErrorNotSupported; + } + + bool pygpukit_fp8_sm90_available() { + return false; + } +} + +#endif diff --git a/native/ops/matmul/matmul_nvf4_bf16_sm120.cu b/native/ops/matmul/matmul_nvf4_bf16_sm120.cu new file mode 100644 index 0000000..25e9261 --- /dev/null +++ b/native/ops/matmul/matmul_nvf4_bf16_sm120.cu @@ -0,0 +1,583 @@ +/** + * NVF4 GEMM implementation for SM120 (Blackwell GeForce) with BF16 I/O + * + * Based on CUTLASS example 79a: blackwell_geforce_nvfp4_bf16_gemm + * + * Data Flow: + * BF16 input -> NVF4 (4-bit) quantize with block scaling -> CUTLASS GEMM -> BF16 output + * + * NVF4 (float_e2m1_t) is a 4-bit format with 2-bit exponent and 1-bit mantissa. + * This provides 2x memory bandwidth compared to FP8, making it ideal for + * memory-bound LLM inference workloads. 
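 *
 * Rough operand sizes for a 4096 x 4096 matrix (about 16.8M elements, ignoring the
 * per-block scale factors): BF16 = 32 MiB, FP8 = 16 MiB, NVF4 = 8 MiB. That halving
 * relative to FP8 is where the 2x bandwidth advantage comes from.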
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include + +// Enable NVF4 SM120 +#define PYGPUKIT_ENABLE_NVF4_SM120 + +// Only compile for SM120+ +#if (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) && defined(PYGPUKIT_ENABLE_NVF4_SM120) + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +using namespace cute; + +namespace pygpukit { +namespace ops { +namespace nvf4_bf16_gemm_sm120 { + +// ============================================================================ +// GEMM Configuration (from example 79a) +// ============================================================================ + +// A matrix configuration +using ElementA = cutlass::nv_float4_t; // NVF4 wrapper type +using LayoutATag = cutlass::layout::RowMajor; +constexpr int AlignmentA = 32; // Memory access granularity + +// B matrix configuration +using ElementB = cutlass::nv_float4_t; // NVF4 wrapper type +using LayoutBTag = cutlass::layout::ColumnMajor; +constexpr int AlignmentB = 32; + +// C/D matrix configuration (BF16 output) +using ElementC = cutlass::bfloat16_t; +using ElementD = cutlass::bfloat16_t; +using LayoutCTag = cutlass::layout::RowMajor; +using LayoutDTag = cutlass::layout::RowMajor; +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // 8 +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // 8 + +// Kernel config +using ElementAccumulator = float; +using ArchTag = cutlass::arch::Sm120; +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + +// Tile shapes - K=256 is recommended for NVF4 in CUTLASS tests +using ThreadBlockShape = Shape<_128, _128, _256>; +using ClusterShape = Shape<_1, _1, _1>; // GeForce: no cluster support + +// Epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto +>::CollectiveOp; + +// Mainloop - Pingpong schedule with auto stage count (explicit 3 causes init failure) +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedPingpong +>::CollectiveOp; + +// GEMM Kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + void +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Types for data layout +using StrideA = typename Gemm::GemmKernel::StrideA; +using LayoutSFA = typename 
Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; +using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + +// Data types for raw storage +using DataTypeA = typename ElementA::DataType; // float_e2m1_t +using ScaleFactorType = typename ElementA::ScaleFactorType; // float_ue4m3_t + +// ============================================================================ +// BF16 -> NVF4 Quantization with Block Scaling +// ============================================================================ + +// NVF4 E2M1 range: [-6.0, 6.0] +constexpr float NVF4_MAX = 6.0f; + +// Convert float to NVF4 E2M1 (4-bit) - HOST version +inline uint8_t bf16_to_nvf4_e2m1_host(float val) { + // E2M1 representable values: 0, 0.5, 1, 1.5, 2, 3, 4, 6 (and negatives) + if (std::abs(val) < 0.25f) return 0; // Zero + + uint8_t sign = (val < 0) ? 0x8 : 0x0; + val = std::abs(val); + val = std::min(val, NVF4_MAX); + + // Quantize to nearest E2M1 value + uint8_t code; + if (val < 0.75f) code = 1; // 0.5 + else if (val < 1.25f) code = 2; // 1.0 + else if (val < 1.75f) code = 3; // 1.5 + else if (val < 2.5f) code = 4; // 2.0 + else if (val < 3.5f) code = 5; // 3.0 + else if (val < 5.0f) code = 6; // 4.0 + else code = 7; // 6.0 + + return sign | code; +} + +// ============================================================================ +// Branchless BF16 -> NVF4 Quantization +// ============================================================================ +// Uses comparison accumulation - faster than LUT on modern GPUs +// LUT approaches tested but slower due to constant memory latency + +__device__ __forceinline__ +uint8_t bf16_to_nvf4_e2m1(float val) { + // E2M1 representable values: 0, 0.5, 1, 1.5, 2, 3, 4, 6 (and negatives) + float absval = fabsf(val); + uint8_t sign = (val < 0.0f) ? 
0x8 : 0x0; + + // Branchless: count how many thresholds we exceed + uint8_t code = 0; + code += (absval >= 0.25f); + code += (absval >= 0.75f); + code += (absval >= 1.25f); + code += (absval >= 1.75f); + code += (absval >= 2.5f); + code += (absval >= 3.5f); + code += (absval >= 5.0f); + + return sign | code; +} + +// ============================================================================ +// GPU-side BF16 -> NVF4 Quantization Kernels (Unit Scale) +// ============================================================================ + +// Vectorized GPU quantization: BF16 [M, K] RowMajor -> NVF4 packed (unit scale) +// Each thread processes 8 BF16 elements -> 4 output bytes using uint4 loads +// Uses branchless float comparison (faster than LUT - see benchmark notes) +__global__ void quantize_A_gpu_kernel( + const nv_bfloat16* __restrict__ input, // [M, K] RowMajor BF16 + uint8_t* __restrict__ output, // Packed NVF4 (size = M*K/2) + int M, int K +) { + // Each thread handles 8 elements (4 output bytes) + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total_quads = (M * K) / 8; + if (idx >= total_quads) return; + + // Vectorized load: 8 BF16 = 16 bytes = uint4 + const uint4* input_vec = reinterpret_cast(input); + uint4 data = input_vec[idx]; + + // Extract BF16 values and convert to float + nv_bfloat16 bf0, bf1, bf2, bf3, bf4, bf5, bf6, bf7; + memcpy(&bf0, reinterpret_cast(&data.x), sizeof(nv_bfloat16)); + memcpy(&bf1, reinterpret_cast(&data.x) + 1, sizeof(nv_bfloat16)); + memcpy(&bf2, reinterpret_cast(&data.y), sizeof(nv_bfloat16)); + memcpy(&bf3, reinterpret_cast(&data.y) + 1, sizeof(nv_bfloat16)); + memcpy(&bf4, reinterpret_cast(&data.z), sizeof(nv_bfloat16)); + memcpy(&bf5, reinterpret_cast(&data.z) + 1, sizeof(nv_bfloat16)); + memcpy(&bf6, reinterpret_cast(&data.w), sizeof(nv_bfloat16)); + memcpy(&bf7, reinterpret_cast(&data.w) + 1, sizeof(nv_bfloat16)); + + // Quantize using branchless float comparison + uint8_t q0 = bf16_to_nvf4_e2m1(__bfloat162float(bf0)); + uint8_t q1 = bf16_to_nvf4_e2m1(__bfloat162float(bf1)); + uint8_t q2 = bf16_to_nvf4_e2m1(__bfloat162float(bf2)); + uint8_t q3 = bf16_to_nvf4_e2m1(__bfloat162float(bf3)); + uint8_t q4 = bf16_to_nvf4_e2m1(__bfloat162float(bf4)); + uint8_t q5 = bf16_to_nvf4_e2m1(__bfloat162float(bf5)); + uint8_t q6 = bf16_to_nvf4_e2m1(__bfloat162float(bf6)); + uint8_t q7 = bf16_to_nvf4_e2m1(__bfloat162float(bf7)); + + // Pack into 4 bytes and write as uint32 + uint32_t packed = ((q1 << 4) | (q0 & 0x0F)) + | (((q3 << 4) | (q2 & 0x0F)) << 8) + | (((q5 << 4) | (q4 & 0x0F)) << 16) + | (((q7 << 4) | (q6 & 0x0F)) << 24); + + reinterpret_cast(output)[idx] = packed; +} + +// GPU quantization: BF16 [K, N] RowMajor -> NVF4 [N, K] ColumnMajor packed (unit scale) +// Vectorized version using shared memory transpose for coalesced access +// TILE_K=64, TILE_N=32: each block processes 64x32 tile, outputs 32x32 packed bytes +__global__ void quantize_B_gpu_kernel( + const nv_bfloat16* __restrict__ input, // [K, N] RowMajor BF16 + uint8_t* __restrict__ output, // Packed NVF4 ColMajor (size = N*K/2) + int K, int N +) { + constexpr int TILE_K = 64; + constexpr int TILE_N = 32; + + // Shared memory: TILE_K x TILE_N with padding to avoid bank conflicts + __shared__ uint8_t smem_q[TILE_K][TILE_N + 4]; + + int block_k = blockIdx.x * TILE_K; + int block_n = blockIdx.y * TILE_N; + + // Phase 1: Load and quantize into shared memory + // 256 threads, each handles 8 elements (64*32/256 = 8) + // Thread layout: 32 threads in N, 8 threads in K + int tid = threadIdx.x; + int 
tn = tid % 32; // 0-31 + int tk = tid / 32; // 0-7 + + #pragma unroll + for (int ki = 0; ki < 8; ki++) { + int k = block_k + tk * 8 + ki; + int n = block_n + tn; + + if (k < K && n < N) { + nv_bfloat16 bf = input[k * N + n]; + smem_q[tk * 8 + ki][tn] = bf16_to_nvf4_e2m1(__bfloat162float(bf)); + } else { + smem_q[tk * 8 + ki][tn] = 0; + } + } + + __syncthreads(); + + // Phase 2: Write transposed and packed (8 NVF4 = 32 bits per write) + // Each thread writes 4 bytes (8 k-values) for one n + // 256 threads handle 32 n-values x 8 k-groups = 256 outputs + int out_n = block_n + (tid % 32); + int out_k_group = tid / 32; // 0-7, each group is 8 k-values + + int k_base = out_k_group * 8; + int num_k_pairs = K / 2; + + if (out_n < N && (block_k + k_base + 7) < K) { + // Fast path: full 8 k-values, vectorized uint32 write + uint8_t q0 = smem_q[k_base + 0][tn]; + uint8_t q1 = smem_q[k_base + 1][tn]; + uint8_t q2 = smem_q[k_base + 2][tn]; + uint8_t q3 = smem_q[k_base + 3][tn]; + uint8_t q4 = smem_q[k_base + 4][tn]; + uint8_t q5 = smem_q[k_base + 5][tn]; + uint8_t q6 = smem_q[k_base + 6][tn]; + uint8_t q7 = smem_q[k_base + 7][tn]; + + uint32_t packed = ((q1 << 4) | (q0 & 0x0F)) + | (((q3 << 4) | (q2 & 0x0F)) << 8) + | (((q5 << 4) | (q4 & 0x0F)) << 16) + | (((q7 << 4) | (q6 & 0x0F)) << 24); + + // Output: ColMajor [N, K] packed - 4 consecutive bytes for 8 k-values + int byte_offset = out_n * num_k_pairs + (block_k + k_base) / 2; + *reinterpret_cast(&output[byte_offset]) = packed; + } else if (out_n < N) { + // Edge case: partial k-group, scalar writes + for (int i = 0; i < 8 && (block_k + k_base + i + 1) < K; i += 2) { + uint8_t q0 = smem_q[k_base + i][tn]; + uint8_t q1 = smem_q[k_base + i + 1][tn]; + output[out_n * num_k_pairs + (block_k + k_base + i) / 2] = (q1 << 4) | (q0 & 0x0F); + } + } +} + +// Initialize scale factors to 1.0 (UE4M3 encoding: 0x38) +__global__ void init_scale_factors_kernel( + uint8_t* __restrict__ sf, + int count +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= count) return; + sf[idx] = 0x38; // float_ue4m3_t(1.0f) = 0x38 +} + +// ============================================================================ +// Host-side BF16 -> NVF4 Quantization Helpers +// ============================================================================ + +// Convert float to float_e2m1_t (NVF4 4-bit format) +inline cutlass::float_e2m1_t float_to_e2m1(float val) { + // E2M1 representable values: 0, 0.5, 1, 1.5, 2, 3, 4, 6 (and negatives) + // Clamp to representable range + val = std::max(-6.0f, std::min(6.0f, val)); + return cutlass::float_e2m1_t(val); +} + +// Convert float to float_ue4m3_t (scale factor, unsigned 8-bit) +inline cutlass::float_ue4m3_t float_to_ue4m3(float val) { + // UE4M3 range: approximately [2^-9, 448] + val = std::max(1.0f/512.0f, std::min(448.0f, val)); + return cutlass::float_ue4m3_t(val); +} + +// Quantize a block of floats to NVF4 with a computed scale factor +// Returns the scale factor used +inline float quantize_block_to_e2m1( + const float* input, + cutlass::float_e2m1_t* output, + int count +) { + // Find max absolute value in block + float max_abs = 0.0f; + for (int i = 0; i < count; ++i) { + max_abs = std::max(max_abs, std::abs(input[i])); + } + + // Compute scale factor: scale * 6.0 >= max_abs + // So scale = max_abs / 6.0 (6.0 is max representable in E2M1) + float scale = (max_abs > 1e-8f) ? 
(max_abs / 6.0f) : 1.0f; + float inv_scale = 1.0f / scale; + + // Quantize each element + for (int i = 0; i < count; ++i) { + float scaled_val = input[i] * inv_scale; + output[i] = float_to_e2m1(scaled_val); + } + + return scale; +} + +// ============================================================================ +// NVF4 GEMM Entry Point (BF16 I/O) +// ============================================================================ + +cudaError_t gemm_nvf4_bf16( + const nv_bfloat16* A, // [M, K] BF16 input (device) + const nv_bfloat16* B, // [K, N] BF16 input (device) + nv_bfloat16* D, // [M, N] BF16 output (device) + int M, int N, int K, + float alpha, + float beta, + cudaStream_t stream +) { + // For SFA and SFB tensors layouts + using Sm1xxBlkScaledConfigLocal = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + // Build strides and layouts + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1)); + StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, 1)); + StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, 1)); + StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, 1)); + + auto problem_shape = cute::make_shape(M, N, K, 1); + LayoutSFA layout_SFA = Sm1xxBlkScaledConfigLocal::tile_atom_to_shape_SFA(problem_shape); + LayoutSFB layout_SFB = Sm1xxBlkScaledConfigLocal::tile_atom_to_shape_SFB(problem_shape); + + // Compute sizes + int64_t size_A = static_cast(M) * K; + int64_t size_B = static_cast(K) * N; + int64_t size_C = static_cast(M) * N; + int64_t size_D = size_C; + + size_t sfa_size = cute::size(cute::filter_zeros(layout_SFA)); + size_t sfb_size = cute::size(cute::filter_zeros(layout_SFB)); + + // WORKAROUND: Blackwell driver TMA bug requires >= 128KB allocations + constexpr size_t MIN_ALLOC_128KB = 128 * 1024; + size_t min_sf_elements = MIN_ALLOC_128KB / sizeof(ScaleFactorType); + + size_t sfa_padded = std::max(sfa_size, min_sf_elements); + size_t sfb_padded = std::max(sfb_size, min_sf_elements); + + // Allocate device memory directly (no host memory needed!) + // NVF4 packed: 2 elements per byte + size_t size_A_packed = (size_A + 1) / 2; // Packed bytes for A + size_t size_B_packed = (size_B + 1) / 2; // Packed bytes for B + + cutlass::device_memory::allocation dev_A(size_A_packed); + cutlass::device_memory::allocation dev_B(size_B_packed); + cutlass::device_memory::allocation dev_SFA(sfa_padded); + cutlass::device_memory::allocation dev_SFB(sfb_padded); + cutlass::device_memory::allocation dev_C(size_C); + // D is used directly - no intermediate allocation needed + + cudaError_t err; + + // Create second stream for parallel quantization + cudaStream_t stream_b; + err = cudaStreamCreate(&stream_b); + if (err != cudaSuccess) return err; + + // Initialize C to zero (on main stream) + err = cudaMemsetAsync(dev_C.get(), 0, size_C * sizeof(ElementC), stream); + if (err != cudaSuccess) { cudaStreamDestroy(stream_b); return err; } + + // ========================================================================= + // GPU-side quantization: BF16 -> NVF4 (PARALLEL on 2 streams!) 
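    // Two caveats worth noting: overlap between the two streams is best-effort (the
    // quantize kernels touch disjoint buffers, but actual concurrency depends on SM
    // occupancy), and quantize_A_gpu_kernel processes 8 elements per thread via
    // (M * K) / 8, so it assumes M * K is a multiple of 8 - any tail elements would
    // be left unquantized.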
+ // Stream A: quantize_A + init_scale_A + // Stream B: quantize_B + init_scale_B + // ========================================================================= + + constexpr int BLOCK_SIZE = 256; + + // Stream A: Quantize A + scale factors + { + int total_quads = (M * K) / 8; + int grid_size = (total_quads + BLOCK_SIZE - 1) / BLOCK_SIZE; + quantize_A_gpu_kernel<<>>( + A, dev_A.get(), M, K + ); + int grid_sfa = (sfa_padded + BLOCK_SIZE - 1) / BLOCK_SIZE; + init_scale_factors_kernel<<>>( + dev_SFA.get(), static_cast(sfa_padded) + ); + } + + // Stream B: Quantize B + scale factors (PARALLEL with stream A) + { + constexpr int TILE_K = 64; + constexpr int TILE_N = 32; + constexpr int B_BLOCK_SIZE = 256; + dim3 grid((K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N); + quantize_B_gpu_kernel<<>>( + B, dev_B.get(), K, N + ); + int grid_sfb = (sfb_padded + BLOCK_SIZE - 1) / BLOCK_SIZE; + init_scale_factors_kernel<<>>( + dev_SFB.get(), static_cast(sfb_padded) + ); + } + + // Wait for both streams to complete + err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) { cudaStreamDestroy(stream_b); return err; } + err = cudaStreamSynchronize(stream_b); + cudaStreamDestroy(stream_b); + if (err != cudaSuccess) return err; + + // Build GEMM arguments - write directly to user buffer D + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + { // Mainloop arguments + reinterpret_cast(dev_A.get()), stride_A, + reinterpret_cast(dev_B.get()), stride_B, + reinterpret_cast(dev_SFA.get()), layout_SFA, + reinterpret_cast(dev_SFB.get()), layout_SFB + }, + { // Epilogue arguments - output directly to D + {alpha, beta}, + dev_C.get(), stride_C, + reinterpret_cast(D), stride_D + } + }; + + // Run GEMM + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[NVF4 GEMM] can_implement failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[NVF4 GEMM] initialize failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + + status = gemm_op.run(stream); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[NVF4 GEMM] run failed: %d\n", static_cast(status)); + return cudaErrorLaunchFailure; + } + + // CUTLASS writes directly to D - no copy needed + return cudaSuccess; +} + +bool is_available() { + int device_id = 0; + cudaGetDevice(&device_id); + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device_id); + return (props.major == 12 && (props.minor == 0 || props.minor == 1)); +} + +} // namespace nvf4_bf16_gemm_sm120 +} // namespace ops +} // namespace pygpukit + +// Extern C for linking +extern "C" { + cudaError_t pygpukit_gemm_nvf4_bf16_sm120( + const nv_bfloat16* A, const nv_bfloat16* B, nv_bfloat16* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return pygpukit::ops::nvf4_bf16_gemm_sm120::gemm_nvf4_bf16(A, B, D, M, N, K, alpha, beta, stream); + } + + bool pygpukit_nvf4_bf16_sm120_available() { + return pygpukit::ops::nvf4_bf16_gemm_sm120::is_available(); + } +} + +#else // !SM120 + +namespace pygpukit { +namespace ops { +namespace nvf4_bf16_gemm_sm120 { + +cudaError_t gemm_nvf4_bf16( + const nv_bfloat16* A, const nv_bfloat16* B, nv_bfloat16* D, + 
int M, int N, int K, + float alpha, float beta, + cudaStream_t stream +) { + return cudaErrorNotSupported; +} + +bool is_available() { + return false; +} + +} // namespace nvf4_bf16_gemm_sm120 +} // namespace ops +} // namespace pygpukit + +extern "C" { + cudaError_t pygpukit_gemm_nvf4_bf16_sm120( + const nv_bfloat16* A, const nv_bfloat16* B, nv_bfloat16* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return cudaErrorNotSupported; + } + + bool pygpukit_nvf4_bf16_sm120_available() { + return false; + } +} + +#endif diff --git a/native/ops/matmul/matmul_nvf4_nvf4_sm120.cu b/native/ops/matmul/matmul_nvf4_nvf4_sm120.cu new file mode 100644 index 0000000..09284ad --- /dev/null +++ b/native/ops/matmul/matmul_nvf4_nvf4_sm120.cu @@ -0,0 +1,460 @@ +/** + * NVF4 GEMM implementation for SM120 (Blackwell GeForce) - Pure NVF4 I/O + * + * Based on CUTLASS example 79a: blackwell_geforce_nvfp4_bf16_gemm + * + * This version takes pre-quantized NVF4 inputs directly to measure + * pure GEMM kernel performance without quantization overhead. + * + * Data Flow: + * NVF4 input (packed) + Scale Factors -> CUTLASS GEMM -> BF16 output + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +// Enable NVF4 SM120 +#define PYGPUKIT_ENABLE_NVF4_SM120 + +// Only compile for SM120+ +#if (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)) && defined(PYGPUKIT_ENABLE_NVF4_SM120) + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/device_memory.h" + +using namespace cute; + +namespace pygpukit { +namespace ops { +namespace nvf4_nvf4_gemm_sm120 { + +// ============================================================================ +// GEMM Configuration (from example 79a) +// ============================================================================ + +// A matrix configuration +using ElementA = cutlass::nv_float4_t; // NVF4 wrapper type +using LayoutATag = cutlass::layout::RowMajor; +constexpr int AlignmentA = 32; // Memory access granularity + +// B matrix configuration +using ElementB = cutlass::nv_float4_t; // NVF4 wrapper type +using LayoutBTag = cutlass::layout::ColumnMajor; +constexpr int AlignmentB = 32; + +// C/D matrix configuration (BF16 output) +using ElementC = cutlass::bfloat16_t; +using ElementD = cutlass::bfloat16_t; +using LayoutCTag = cutlass::layout::RowMajor; +using LayoutDTag = cutlass::layout::RowMajor; +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // 8 +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // 8 + +// Kernel config +using ElementAccumulator = float; +using ArchTag = cutlass::arch::Sm120; +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + +// Tile shapes - 128x128x128 (baseline, optimal for SM120) +using ThreadBlockShape = Shape<_128, _128, _128>; +using ClusterShape = Shape<_1, _1, _1>; // GeForce: no cluster support + +// Epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + 
ThreadBlockShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto +>::CollectiveOp; + +// Mainloop - Pingpong schedule with 3-stage pipeline (optimal for SM120) +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCount<3>, // 3 stages optimal (2=base, 4=too much smem) + cutlass::gemm::KernelTmaWarpSpecializedPingpong +>::CollectiveOp; + +// GEMM Kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Types for data layout +using StrideA = typename Gemm::GemmKernel::StrideA; +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; +using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + +// Data types for raw storage +using DataTypeA = typename ElementA::DataType; // float_e2m1_t +using ScaleFactorType = typename ElementA::ScaleFactorType; // float_ue4m3_t + +// ============================================================================ +// NVF4 GEMM Entry Point (Pre-quantized NVF4 I/O) +// ============================================================================ + +cudaError_t gemm_nvf4_nvf4( + const uint8_t* A_packed, // [M, K] NVF4 packed (M*K/2 bytes), RowMajor + const uint8_t* B_packed, // [N, K] NVF4 packed (N*K/2 bytes), ColMajor + const uint8_t* SFA, // Scale factors for A + const uint8_t* SFB, // Scale factors for B + nv_bfloat16* D, // [M, N] BF16 output (device) + int M, int N, int K, + float alpha, + float beta, + cudaStream_t stream +) { + // For SFA and SFB tensors layouts + using Sm1xxBlkScaledConfigLocal = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + // Build strides and layouts + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1)); + StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, 1)); + StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, 1)); + StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, 1)); + + auto problem_shape = cute::make_shape(M, N, K, 1); + LayoutSFA layout_SFA = Sm1xxBlkScaledConfigLocal::tile_atom_to_shape_SFA(problem_shape); + LayoutSFB layout_SFB = Sm1xxBlkScaledConfigLocal::tile_atom_to_shape_SFB(problem_shape); + + // Compute sizes + int64_t size_C = static_cast(M) * N; + int64_t size_D = size_C; + + // Allocate output buffers + cutlass::device_memory::allocation dev_C(size_C); + cutlass::device_memory::allocation dev_D_out(size_D); + + cudaError_t err; + + // Initialize C to zero + err = cudaMemsetAsync(dev_C.get(), 0, size_C * sizeof(ElementC), stream); + if (err != cudaSuccess) return err; + + // Build GEMM arguments using pre-quantized device memory + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + { // 
Mainloop arguments + reinterpret_cast(A_packed), stride_A, + reinterpret_cast(B_packed), stride_B, + reinterpret_cast(SFA), layout_SFA, + reinterpret_cast(SFB), layout_SFB + }, + { // Epilogue arguments + {alpha, beta}, + dev_C.get(), stride_C, + dev_D_out.get(), stride_D + } + }; + + // Run GEMM + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[NVF4 GEMM] can_implement failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[NVF4 GEMM] initialize failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + + status = gemm_op.run(stream); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[NVF4 GEMM] run failed: %d\n", static_cast(status)); + return cudaErrorLaunchFailure; + } + + // Copy result from CUTLASS output buffer to user-provided D buffer (D2D only!) + err = cudaMemcpyAsync(D, dev_D_out.get(), + size_D * sizeof(nv_bfloat16), + cudaMemcpyDeviceToDevice, stream); + if (err != cudaSuccess) { + return err; + } + + return cudaSuccess; +} + +// ============================================================================ +// Benchmark helper: prepare pre-quantized data and run GEMM +// ============================================================================ + +// Initialize scale factors to 1.0 (UE4M3 encoding: 0x38) +__global__ void init_scale_factors_kernel(uint8_t* sf, int count) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= count) return; + sf[idx] = 0x38; // float_ue4m3_t(1.0f) = 0x38 +} + +// Initialize NVF4 data to 1.0 (E2M1 encoding: 0x22 = two 1.0 values packed) +__global__ void init_nvf4_ones_kernel(uint8_t* data, int count) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= count) return; + // E2M1 1.0 = 0x2, packed: low nibble = 0x2, high nibble = 0x2 -> 0x22 + data[idx] = 0x22; +} + +// Benchmark entry point: allocates, initializes, and runs GEMM (all inline) +cudaError_t benchmark_gemm_nvf4( + nv_bfloat16* D, // [M, N] BF16 output (device, pre-allocated) + int M, int N, int K, + float alpha, + float beta, + cudaStream_t stream +) { + using Sm1xxBlkScaledConfigLocal = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + // Build strides and layouts + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1)); + StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, 1)); + StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, 1)); + StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, 1)); + + auto problem_shape = cute::make_shape(M, N, K, 1); + LayoutSFA layout_SFA = Sm1xxBlkScaledConfigLocal::tile_atom_to_shape_SFA(problem_shape); + LayoutSFB layout_SFB = Sm1xxBlkScaledConfigLocal::tile_atom_to_shape_SFB(problem_shape); + + // Compute sizes + int64_t size_A = static_cast(M) * K; + int64_t size_B = static_cast(K) * N; + int64_t size_C = static_cast(M) * N; + int64_t size_D = size_C; + + size_t sfa_size = cute::size(cute::filter_zeros(layout_SFA)); + size_t sfb_size = cute::size(cute::filter_zeros(layout_SFB)); + + // WORKAROUND: Blackwell driver TMA bug requires >= 128KB allocations + constexpr size_t 
MIN_ALLOC_128KB = 128 * 1024; + size_t min_sf_elements = MIN_ALLOC_128KB / sizeof(ScaleFactorType); + + size_t sfa_padded = std::max(sfa_size, min_sf_elements); + size_t sfb_padded = std::max(sfb_size, min_sf_elements); + + // NVF4 packed sizes (with 128KB minimum) + size_t size_A_packed = (size_A + 1) / 2; + size_t size_B_packed = (size_B + 1) / 2; + size_t size_A_padded = std::max(size_A_packed, MIN_ALLOC_128KB); + size_t size_B_padded = std::max(size_B_packed, MIN_ALLOC_128KB); + + // Allocate device memory (no need to allocate D - use user buffer directly) + cutlass::device_memory::allocation dev_A(size_A_padded); + cutlass::device_memory::allocation dev_B(size_B_padded); + cutlass::device_memory::allocation dev_SFA(sfa_padded); + cutlass::device_memory::allocation dev_SFB(sfb_padded); + cutlass::device_memory::allocation dev_C(size_C); + + cudaError_t err; + + // Initialize C to zero + err = cudaMemsetAsync(dev_C.get(), 0, size_C * sizeof(ElementC), stream); + if (err != cudaSuccess) return err; + + constexpr int BLOCK_SIZE = 256; + + // Initialize A and B to 1.0 + { + int grid_a = (size_A_padded + BLOCK_SIZE - 1) / BLOCK_SIZE; + int grid_b = (size_B_padded + BLOCK_SIZE - 1) / BLOCK_SIZE; + init_nvf4_ones_kernel<<>>(dev_A.get(), size_A_padded); + init_nvf4_ones_kernel<<>>(dev_B.get(), size_B_padded); + } + + // Initialize scale factors to 1.0 + { + int grid_sfa = (sfa_padded + BLOCK_SIZE - 1) / BLOCK_SIZE; + int grid_sfb = (sfb_padded + BLOCK_SIZE - 1) / BLOCK_SIZE; + init_scale_factors_kernel<<>>(dev_SFA.get(), sfa_padded); + init_scale_factors_kernel<<>>(dev_SFB.get(), sfb_padded); + } + + // Sync before GEMM + err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) return err; + + // Build GEMM arguments - use D directly (no intermediate buffer) + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + { // Mainloop arguments + reinterpret_cast(dev_A.get()), stride_A, + reinterpret_cast(dev_B.get()), stride_B, + reinterpret_cast(dev_SFA.get()), layout_SFA, + reinterpret_cast(dev_SFB.get()), layout_SFB + }, + { // Epilogue arguments - write directly to user buffer + {alpha, beta}, + dev_C.get(), stride_C, + reinterpret_cast(D), stride_D + } + }; + + // Run GEMM + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[NVF4 Bench] can_implement failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[NVF4 Bench] initialize failed: %d\n", static_cast(status)); + return cudaErrorInvalidValue; + } + + status = gemm_op.run(stream); + if (status != cutlass::Status::kSuccess) { + fprintf(stderr, "[NVF4 Bench] run failed: %d\n", static_cast(status)); + return cudaErrorLaunchFailure; + } + + // No D2D copy needed - CUTLASS writes directly to user buffer D + return cudaSuccess; +} + +bool is_available() { + int device_id = 0; + cudaGetDevice(&device_id); + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device_id); + return (props.major == 12 && (props.minor == 0 || props.minor == 1)); +} + +} // namespace nvf4_nvf4_gemm_sm120 +} // namespace ops +} // namespace pygpukit + +// Extern C for linking +extern "C" { + cudaError_t pygpukit_gemm_nvf4_nvf4_sm120( + const uint8_t* A_packed, 
const uint8_t* B_packed, + const uint8_t* SFA, const uint8_t* SFB, + nv_bfloat16* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return pygpukit::ops::nvf4_nvf4_gemm_sm120::gemm_nvf4_nvf4( + A_packed, B_packed, SFA, SFB, D, M, N, K, alpha, beta, stream + ); + } + + cudaError_t pygpukit_benchmark_gemm_nvf4_sm120( + nv_bfloat16* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return pygpukit::ops::nvf4_nvf4_gemm_sm120::benchmark_gemm_nvf4( + D, M, N, K, alpha, beta, stream + ); + } + + bool pygpukit_nvf4_nvf4_sm120_available() { + return pygpukit::ops::nvf4_nvf4_gemm_sm120::is_available(); + } +} + +#else // !SM120 + +namespace pygpukit { +namespace ops { +namespace nvf4_nvf4_gemm_sm120 { + +cudaError_t gemm_nvf4_nvf4( + const uint8_t* A_packed, const uint8_t* B_packed, + const uint8_t* SFA, const uint8_t* SFB, + nv_bfloat16* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream +) { + return cudaErrorNotSupported; +} + +cudaError_t benchmark_gemm_nvf4( + nv_bfloat16* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream +) { + return cudaErrorNotSupported; +} + +bool is_available() { + return false; +} + +} // namespace nvf4_nvf4_gemm_sm120 +} // namespace ops +} // namespace pygpukit + +extern "C" { + cudaError_t pygpukit_gemm_nvf4_nvf4_sm120( + const uint8_t* A_packed, const uint8_t* B_packed, + const uint8_t* SFA, const uint8_t* SFB, + nv_bfloat16* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return cudaErrorNotSupported; + } + + cudaError_t pygpukit_benchmark_gemm_nvf4_sm120( + nv_bfloat16* D, + int M, int N, int K, + float alpha, float beta, + cudaStream_t stream + ) { + return cudaErrorNotSupported; + } + + bool pygpukit_nvf4_nvf4_sm120_available() { + return false; + } +} + +#endif diff --git a/native/ops/matmul_cutlass.cuh b/native/ops/matmul_cutlass.cuh index a4e85cb..667c1ce 100644 --- a/native/ops/matmul_cutlass.cuh +++ b/native/ops/matmul_cutlass.cuh @@ -35,6 +35,7 @@ #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_batched.h" #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/epilogue/thread/linear_combination_gelu.h" #include "cutlass/util/device_memory.h" @@ -84,9 +85,15 @@ inline int get_cached_sm_version() { // Minimum supported SM version constexpr int MIN_SM_VERSION = 80; -// Check if SM version is supported +// Check if SM version is supported for CUTLASS kernels +// Note: SM 120 (Blackwell GeForce) can use CUTLASS 2.x kernels (SM80 ArchTag) +// as a fallback since Blackwell supports all Ampere instructions. +// CUTLASS 4.x native SM120 kernels only support FP8, so we use SM80 path. 
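Both `is_available()` above (which accepts only SM 12.0/12.1) and `is_sm_supported()` below reduce to reading the device's compute capability at runtime; a minimal standalone sketch of that query (illustrative only; `query_sm_version` is a hypothetical helper, not the header's `get_cached_sm_version`):

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Illustrative only: returns major*10 + minor (e.g. 120 for SM 12.0), or -1 on error.
static int query_sm_version() {
    int device = 0;
    if (cudaGetDevice(&device) != cudaSuccess) return -1;
    cudaDeviceProp props{};
    if (cudaGetDeviceProperties(&props, device) != cudaSuccess) return -1;
    return props.major * 10 + props.minor;
}

int main() {
    int sm = query_sm_version();
    std::printf("SM version: %d\n", sm);
    std::printf("CUTLASS 2.x (SM80 ArchTag) eligible: %s\n", sm >= 80 ? "yes" : "no");
    std::printf("Blackwell GeForce (SM 12.x): %s\n", (sm >= 120 && sm < 130) ? "yes" : "no");
    return 0;
}
```

On a Blackwell GeForce part this reports 120 or 121, which is exactly the case the SM80-ArchTag fallback described above exists for.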
 inline bool is_sm_supported() {
-    return get_cached_sm_version() >= MIN_SM_VERSION;
+    int sm = get_cached_sm_version();
+    // SM 80+: CUTLASS 2.x/3.x kernels work
+    // SM 120: Uses CUTLASS 2.x (SM80 ArchTag) as fallback
+    return sm >= MIN_SM_VERSION;
 }
 
 // SM version classification for kernel selection
@@ -189,6 +196,34 @@ using TF32Gemm_Sm89 = cutlass::gemm::device::Gemm<
 // Default alias (SM80 for backward compatibility)
 using TF32Gemm = TF32Gemm_Sm80;
 
+// ============================================================================
+// TF32 Batched GEMM (FP32 input/output, TF32 TensorCore for batch operations)
+// ============================================================================
+
+// SM86 (RTX 30xx): 5-stage pipeline for batched operations
+using TF32GemmBatched_Sm86 = cutlass::gemm::device::GemmBatched<
+    float,                                   // ElementA (will be B^T)
+    cutlass::layout::ColumnMajor,            // LayoutA
+    float,                                   // ElementB (will be A^T)
+    cutlass::layout::ColumnMajor,            // LayoutB
+    float,                                   // ElementC (will be C^T)
+    cutlass::layout::ColumnMajor,            // LayoutC
+    float,                                   // ElementAccumulator
+    cutlass::arch::OpClassTensorOp,          // OperatorClass (TensorCore)
+    cutlass::arch::Sm80,                     // ArchTag (Ampere TensorCore compatible)
+    cutlass::gemm::GemmShape<128, 128, 16>,  // ThreadBlockShape
+    cutlass::gemm::GemmShape<64, 64, 16>,    // WarpShape
+    cutlass::gemm::GemmShape<16, 8, 8>,      // InstructionShape (mma.sync)
+    cutlass::epilogue::thread::LinearCombination<
+        float, 128 / cutlass::sizeof_bits<float>::value,
+        float, float>,                       // EpilogueOp (128-bit aligned)
+    cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
+    5                                        // Stages (5-stage for SM86)
+>;
+
+// Default batched alias
+using TF32GemmBatched = TF32GemmBatched_Sm86;
+
 // ============================================================================
 // FP16 GEMM (FP16 input/output, FP16 TensorCore)
 // ============================================================================
@@ -589,37 +624,39 @@ inline cudaError_t gemm_tf32(
     // Runtime SM dispatch with tiered kernel selection
     int sm_tier = get_sm_tier();
 
-    // NOTE: SM120 CUTLASS 4.x kernels are DISABLED (FP8 only).
-    // SM100 (B200) supports FP32/FP16/BF16.
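The CUTLASS 2.x fallback paths in the dispatch below rely on the transpose trick spelled out in the comments: a row-major C = A·B is computed as a column-major C^T = B^T·A^T, so the kernel is invoked with the operands swapped and leading dimensions N, K, N. A small CPU-only sanity check of that identity (all names are local to the example, not part of the patch):

```cpp
#include <cstdio>
#include <vector>

// Naive column-major GEMM: C = alpha*A*B + beta*C with leading dimensions
// lda/ldb/ldc. This mirrors how the column-major CUTLASS kernel sees the buffers.
static void gemm_colmajor(int M, int N, int K, float alpha,
                          const float* A, int lda, const float* B, int ldb,
                          float beta, float* C, int ldc) {
    for (int n = 0; n < N; ++n)
        for (int m = 0; m < M; ++m) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) acc += A[m + k * lda] * B[k + n * ldb];
            C[m + n * ldc] = alpha * acc + beta * C[m + n * ldc];
        }
}

int main() {
    const int M = 2, N = 3, K = 4;
    std::vector<float> A(M * K), B(K * N), C(M * N, 0.0f), C_ref(M * N, 0.0f);
    for (int i = 0; i < M * K; ++i) A[i] = float(i + 1);
    for (int i = 0; i < K * N; ++i) B[i] = float(2 * i + 1);

    // Reference: row-major C = A @ B.
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
            for (int k = 0; k < K; ++k)
                C_ref[m * N + n] += A[m * K + k] * B[k * N + n];

    // Transpose trick: the same row-major buffers, viewed as column-major, give
    // C^T (NxM) = B^T (NxK) @ A^T (KxM). Hence problem size (N, M, K), operands
    // (B, ld=N) and (A, ld=K), output (C, ld=N) -- the argument order of the
    // run_gemm(problem_size, B, N, A, K, C, N, C, N, ...) calls below.
    gemm_colmajor(N, M, K, 1.0f, B.data(), N, A.data(), K, 0.0f, C.data(), N);

    for (int i = 0; i < M * N; ++i)
        if (C[i] != C_ref[i]) { std::printf("mismatch at %d\n", i); return 1; }
    std::printf("transpose trick OK\n");
    return 0;
}
```

With these small integer inputs both orderings produce bit-identical results, which is the property every fallback branch below depends on.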
+    // SM120 (Blackwell GeForce): Use CUTLASS 2.x (SM86) as fallback
+    // CUTLASS 4.x native SM120 kernels only support FP8, not FP32/FP16/BF16
+    // SM100/SM90 kernels also don't work on SM120 (different tensor core gen)
 
-    // SM100+ (Blackwell datacenter: B200) - CUTLASS 4.x with 2SM MMA
+    // SM100 (Blackwell datacenter: B200 only, NOT SM120)
 #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
-    if (sm_tier >= 100) {
+    if (sm_tier >= 100 && sm_tier < 120) {
         return cutlass_gemm_sm100::gemm_tf32_sm100(A, B, C, M, N, K, alpha, beta, stream);
     }
 #endif
 
-    // SM90+ (Hopper: H100) - CUTLASS 3.x with WGMMA/TMA
+    // SM90-99 (Hopper: H100 only)
 #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
-    if (sm_tier >= 90) {
+    if (sm_tier >= 90 && sm_tier < 100) {
         return cutlass_gemm_sm90::gemm_tf32_sm90(A, B, C, M, N, K, alpha, beta, stream);
     }
 #endif
 
-    // Fallback to CUTLASS 2.x API for SM80-89 (and SM120 until FP8 support)
+    // CUTLASS 2.x API for SM80-89 AND SM120+ (Blackwell GeForce fallback)
     // Transpose trick: C^T (NxM col) = B^T (NxK col) @ A^T (KxM col)
     cutlass::gemm::GemmCoord problem_size(N, M, K);
 
-    if (sm_tier >= 89) {
-        // SM89 (Ada): 6-stage pipeline with larger tiles
-        return run_gemm<TF32Gemm_Sm89>(
+    // SM120+ uses SM86 kernel (5-stage, works on Blackwell)
+    if (sm_tier >= 120 || sm_tier == 89) {
+        // SM120 (Blackwell GeForce) / SM89 (Ada): Use SM86 5-stage for stability
+        return run_gemm<TF32Gemm_Sm86>(
             problem_size, B, N, A, K, C, N, C, N, alpha, beta, stream);
     } else if (sm_tier >= 86) {
-        // SM86 (Ampere consumer): 5-stage pipeline
+        // SM86-88 (Ampere consumer): 5-stage pipeline
         return run_gemm<TF32Gemm_Sm86>(
             problem_size, B, N, A, K, C, N, C, N, alpha, beta, stream);
     } else {
-        // SM80 (Ampere datacenter): 4-stage pipeline
+        // SM80-85 (Ampere datacenter): 4-stage pipeline
        return run_gemm<TF32Gemm_Sm80>(
             problem_size, B, N, A, K, C, N, C, N, alpha, beta, stream);
     }
@@ -641,36 +678,33 @@ inline cudaError_t gemm_fp16(
     // Runtime SM dispatch with tiered kernel selection
     int sm_tier = get_sm_tier();
 
-    // NOTE: SM120 CUTLASS 4.x kernels are DISABLED (FP8 only).
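The tier selection above is repeated verbatim in the FP16 and BF16 dispatchers that follow; a condensed sketch of the policy, omitting the compile-time `CUTLASS_ARCH_MMA_*` guards (`KernelTier` and `pick_kernel_tier` are hypothetical names, not part of the patch):

```cpp
// Condensed view of the runtime dispatch used by gemm_tf32 / gemm_fp16 / gemm_bf16.
enum class KernelTier { Sm80, Sm86, Sm90, Sm100 };

inline KernelTier pick_kernel_tier(int sm_tier) {
    if (sm_tier >= 120) return KernelTier::Sm86;   // Blackwell GeForce: CUTLASS 2.x fallback
    if (sm_tier >= 100) return KernelTier::Sm100;  // Blackwell datacenter (B200)
    if (sm_tier >= 90)  return KernelTier::Sm90;   // Hopper (H100)
    if (sm_tier >= 86)  return KernelTier::Sm86;   // Ada (89) and Ampere consumer (86-88)
    return KernelTier::Sm80;                       // Ampere datacenter (80-85)
}
```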
- - // SM100+ (Blackwell datacenter: B200) + // SM100 (Blackwell datacenter: B200 only, NOT SM120) #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - if (sm_tier >= 100) { + if (sm_tier >= 100 && sm_tier < 120) { return cutlass_gemm_sm100::gemm_fp16_sm100(A, B, C, M, N, K, alpha, beta, stream); } #endif - // SM90+ (Hopper: H100) + // SM90-99 (Hopper: H100 only) #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) - if (sm_tier >= 90) { + if (sm_tier >= 90 && sm_tier < 100) { return cutlass_gemm_sm90::gemm_fp16_sm90(A, B, C, M, N, K, alpha, beta, stream); } #endif - // Fallback to CUTLASS 2.x API for SM80-89 (and SM120 until FP8 support) - // Transpose trick: C^T = B^T @ A^T + // CUTLASS 2.x API for SM80-89 AND SM120+ (Blackwell GeForce fallback) cutlass::gemm::GemmCoord problem_size(N, M, K); - if (sm_tier >= 89) { - // SM89 (Ada): 6-stage pipeline with larger tiles - return run_gemm( + if (sm_tier >= 120 || sm_tier == 89) { + // SM120 (Blackwell GeForce) / SM89 (Ada): Use SM86 5-stage + return run_gemm( problem_size, B, N, A, K, C, N, C, N, alpha, beta, stream); } else if (sm_tier >= 86) { - // SM86 (Ampere consumer): 5-stage pipeline + // SM86-88 (Ampere consumer): 5-stage pipeline return run_gemm( problem_size, B, N, A, K, C, N, C, N, alpha, beta, stream); } else { - // SM80 (Ampere datacenter): 4-stage pipeline + // SM80-85 (Ampere datacenter): 4-stage pipeline return run_gemm( problem_size, B, N, A, K, C, N, C, N, alpha, beta, stream); } @@ -692,36 +726,33 @@ inline cudaError_t gemm_bf16( // Runtime SM dispatch with tiered kernel selection int sm_tier = get_sm_tier(); - // NOTE: SM120 CUTLASS 4.x kernels are DISABLED (FP8 only). - - // SM100+ (Blackwell datacenter: B200) + // SM100 (Blackwell datacenter: B200 only, NOT SM120) #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - if (sm_tier >= 100) { + if (sm_tier >= 100 && sm_tier < 120) { return cutlass_gemm_sm100::gemm_bf16_sm100(A, B, C, M, N, K, alpha, beta, stream); } #endif - // SM90+ (Hopper: H100) + // SM90-99 (Hopper: H100 only) #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) - if (sm_tier >= 90) { + if (sm_tier >= 90 && sm_tier < 100) { return cutlass_gemm_sm90::gemm_bf16_sm90(A, B, C, M, N, K, alpha, beta, stream); } #endif - // Fallback to CUTLASS 2.x API for SM80-89 (and SM120 until FP8 support) - // Transpose trick: C^T = B^T @ A^T + // CUTLASS 2.x API for SM80-89 AND SM120+ (Blackwell GeForce fallback) cutlass::gemm::GemmCoord problem_size(N, M, K); - if (sm_tier >= 89) { - // SM89 (Ada): 6-stage pipeline with larger tiles - return run_gemm( + if (sm_tier >= 120 || sm_tier == 89) { + // SM120 (Blackwell GeForce) / SM89 (Ada): Use SM86 5-stage + return run_gemm( problem_size, B, N, A, K, C, N, C, N, alpha, beta, stream); } else if (sm_tier >= 86) { - // SM86 (Ampere consumer): 5-stage pipeline + // SM86-88 (Ampere consumer): 5-stage pipeline return run_gemm( problem_size, B, N, A, K, C, N, C, N, alpha, beta, stream); } else { - // SM80 (Ampere datacenter): 4-stage pipeline + // SM80-85 (Ampere datacenter): 4-stage pipeline return run_gemm( problem_size, B, N, A, K, C, N, C, N, alpha, beta, stream); } @@ -858,6 +889,107 @@ inline cudaError_t gemm_bf16_bias_gelu( } } +// ============================================================================ +// Batched GEMM Implementation +// ============================================================================ + +/** + * Template helper for batched GEMM dispatch + * + * Memory layout for strided batched GEMM: + * - A[batch, M, K] row-major: stride_A = M * K + * - B[batch, K, N] row-major: 
stride_B = K * N + * - C[batch, M, N] row-major: stride_C = M * N + * + * Using the transpose trick for CUTLASS column-major kernels: + * - C^T[batch, N, M] = B^T[batch, N, K] @ A^T[batch, K, M] + */ +template +inline cudaError_t run_gemm_batched( + cutlass::gemm::GemmCoord problem_size, + const void* A, int ldA, int64_t strideA, + const void* B, int ldB, int64_t strideB, + void* C, int ldC, int64_t strideC, + float alpha, float beta, + int batch_count, + cudaStream_t stream +) { + using ElementA = typename GemmBatchedOp::ElementA; + using ElementB = typename GemmBatchedOp::ElementB; + using ElementC = typename GemmBatchedOp::ElementC; + + typename GemmBatchedOp::Arguments arguments{ + problem_size, + {static_cast(A), ldA}, + strideA, + {static_cast(B), ldB}, + strideB, + {static_cast(C), ldC}, + strideC, + {static_cast(C), ldC}, + strideC, + {alpha, beta}, + batch_count + }; + + GemmBatchedOp gemm_op; + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return cudaErrorInvalidValue; + } + + size_t workspace_size = GemmBatchedOp::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get(), stream); + if (status != cutlass::Status::kSuccess) { + return cudaErrorInvalidValue; + } + + status = gemm_op(stream); + if (status != cutlass::Status::kSuccess) { + return cudaErrorInvalidValue; + } + + return cudaSuccess; +} + +/** + * FP32 Strided Batched GEMM using CUTLASS TensorCore (TF32) + * + * Computes: C[b] = A[b] @ B[b] for b in [0, batch_count) + * Where A[batch, M, K], B[batch, K, N], C[batch, M, N] are row-major. + */ +inline cudaError_t gemm_batched_fp32( + const float* A, + const float* B, + float* C, + int M, int N, int K, + int batch_count, + int64_t strideA, + int64_t strideB, + int64_t strideC, + float alpha = 1.0f, + float beta = 0.0f, + cudaStream_t stream = nullptr +) { + // Transpose trick: C^T[N,M] = B^T[N,K] @ A^T[K,M] + // For batched: each batch element uses the same transformation + cutlass::gemm::GemmCoord problem_size(N, M, K); + + // Note: Strides remain the same (element count between batches) + // but the roles of A/B are swapped for the transpose trick + return run_gemm_batched( + problem_size, + B, N, strideB, // B^T as first operand (ld = N) + A, K, strideA, // A^T as second operand (ld = K) + C, N, strideC, // C^T as output (ld = N) + alpha, beta, + batch_count, + stream + ); +} + // ============================================================================ // Dispatch function for runtime dtype selection // ============================================================================ diff --git a/native/ops/nn/activation_kernels.cuh b/native/ops/nn/activation_kernels.cuh index a569f06..a27e15f 100644 --- a/native/ops/nn/activation_kernels.cuh +++ b/native/ops/nn/activation_kernels.cuh @@ -119,6 +119,109 @@ __global__ void silu_bf16_kernel(const __nv_bfloat16* __restrict__ input, } } +// ============================================================================ +// ReLU Activation: max(0, x) +// ============================================================================ + +__global__ void relu_f32_kernel(const float* __restrict__ input, + float* __restrict__ output, + size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = fmaxf(0.0f, input[idx]); + } +} + +__global__ void relu_f16_kernel(const __half* __restrict__ input, + __half* __restrict__ output, + size_t n) { + 
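+    // Half path: widen to fp32, apply fmaxf, then narrow back to __half
+    // (the BF16 kernel below uses the same widen/compute/narrow pattern)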
size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __half2float(input[idx]); + output[idx] = __float2half(fmaxf(0.0f, x)); + } +} + +__global__ void relu_bf16_kernel(const __nv_bfloat16* __restrict__ input, + __nv_bfloat16* __restrict__ output, + size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __bfloat162float(input[idx]); + output[idx] = __float2bfloat16(fmaxf(0.0f, x)); + } +} + +// ============================================================================ +// Sigmoid Activation: 1 / (1 + exp(-x)) +// ============================================================================ + +__device__ __forceinline__ float sigmoid_f32(float x) { + return 1.0f / (1.0f + expf(-x)); +} + +__global__ void sigmoid_f32_kernel(const float* __restrict__ input, + float* __restrict__ output, + size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = sigmoid_f32(input[idx]); + } +} + +__global__ void sigmoid_f16_kernel(const __half* __restrict__ input, + __half* __restrict__ output, + size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __half2float(input[idx]); + output[idx] = __float2half(sigmoid_f32(x)); + } +} + +__global__ void sigmoid_bf16_kernel(const __nv_bfloat16* __restrict__ input, + __nv_bfloat16* __restrict__ output, + size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __bfloat162float(input[idx]); + output[idx] = __float2bfloat16(sigmoid_f32(x)); + } +} + +// ============================================================================ +// Tanh Activation +// ============================================================================ + +__global__ void tanh_f32_kernel(const float* __restrict__ input, + float* __restrict__ output, + size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = tanhf(input[idx]); + } +} + +__global__ void tanh_f16_kernel(const __half* __restrict__ input, + __half* __restrict__ output, + size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __half2float(input[idx]); + output[idx] = __float2half(tanhf(x)); + } +} + +__global__ void tanh_bf16_kernel(const __nv_bfloat16* __restrict__ input, + __nv_bfloat16* __restrict__ output, + size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __bfloat162float(input[idx]); + output[idx] = __float2bfloat16(tanhf(x)); + } +} + } // namespace nn } // namespace ops } // namespace pygpukit diff --git a/native/ops/nn/memory_kernels.cuh b/native/ops/nn/memory_kernels.cuh index 0299f6e..0bf1353 100644 --- a/native/ops/nn/memory_kernels.cuh +++ b/native/ops/nn/memory_kernels.cuh @@ -349,6 +349,234 @@ __global__ void transpose_021_bf16_kernel( } } +// ============================================================================ +// 3D Transpose: [d0, d1, d2] -> [d0, d2, d1] +// Swaps last two axes (common in attention) +// ============================================================================ + +__global__ void transpose_012_f32_kernel( + const float* __restrict__ src, + float* __restrict__ dst, + size_t dim0, + size_t dim1, + size_t dim2 +) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = dim0 * dim1 * dim2; + + if (idx < total) { + // Compute source coordinates [d0, d1, d2] + size_t d2 = idx % dim2; + size_t remaining = idx / dim2; + size_t d1 = remaining % dim1; + size_t d0 = remaining / 
dim1; + + // Compute destination index [d0, d2, d1] + size_t dst_idx = d0 * dim2 * dim1 + d2 * dim1 + d1; + dst[dst_idx] = src[idx]; + } +} + +__global__ void transpose_012_f16_kernel( + const __half* __restrict__ src, + __half* __restrict__ dst, + size_t dim0, + size_t dim1, + size_t dim2 +) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = dim0 * dim1 * dim2; + + if (idx < total) { + size_t d2 = idx % dim2; + size_t remaining = idx / dim2; + size_t d1 = remaining % dim1; + size_t d0 = remaining / dim1; + + size_t dst_idx = d0 * dim2 * dim1 + d2 * dim1 + d1; + dst[dst_idx] = src[idx]; + } +} + +__global__ void transpose_012_bf16_kernel( + const __nv_bfloat16* __restrict__ src, + __nv_bfloat16* __restrict__ dst, + size_t dim0, + size_t dim1, + size_t dim2 +) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = dim0 * dim1 * dim2; + + if (idx < total) { + size_t d2 = idx % dim2; + size_t remaining = idx / dim2; + size_t d1 = remaining % dim1; + size_t d0 = remaining / dim1; + + size_t dst_idx = d0 * dim2 * dim1 + d2 * dim1 + d1; + dst[dst_idx] = src[idx]; + } +} + +// ============================================================================ +// 4D Transpose: [d0, d1, d2, d3] -> [d0, d2, d1, d3] +// Swaps axes 1 and 2 (common in attention: batch, seq, heads, dim -> batch, heads, seq, dim) +// ============================================================================ + +__global__ void transpose_0213_f32_kernel( + const float* __restrict__ src, + float* __restrict__ dst, + size_t dim0, + size_t dim1, + size_t dim2, + size_t dim3 +) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = dim0 * dim1 * dim2 * dim3; + + if (idx < total) { + // Compute source coordinates [d0, d1, d2, d3] + size_t d3 = idx % dim3; + size_t remaining = idx / dim3; + size_t d2 = remaining % dim2; + remaining = remaining / dim2; + size_t d1 = remaining % dim1; + size_t d0 = remaining / dim1; + + // Compute destination index [d0, d2, d1, d3] + size_t dst_idx = d0 * (dim2 * dim1 * dim3) + d2 * (dim1 * dim3) + d1 * dim3 + d3; + dst[dst_idx] = src[idx]; + } +} + +__global__ void transpose_0213_f16_kernel( + const __half* __restrict__ src, + __half* __restrict__ dst, + size_t dim0, + size_t dim1, + size_t dim2, + size_t dim3 +) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = dim0 * dim1 * dim2 * dim3; + + if (idx < total) { + size_t d3 = idx % dim3; + size_t remaining = idx / dim3; + size_t d2 = remaining % dim2; + remaining = remaining / dim2; + size_t d1 = remaining % dim1; + size_t d0 = remaining / dim1; + + size_t dst_idx = d0 * (dim2 * dim1 * dim3) + d2 * (dim1 * dim3) + d1 * dim3 + d3; + dst[dst_idx] = src[idx]; + } +} + +__global__ void transpose_0213_bf16_kernel( + const __nv_bfloat16* __restrict__ src, + __nv_bfloat16* __restrict__ dst, + size_t dim0, + size_t dim1, + size_t dim2, + size_t dim3 +) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = dim0 * dim1 * dim2 * dim3; + + if (idx < total) { + size_t d3 = idx % dim3; + size_t remaining = idx / dim3; + size_t d2 = remaining % dim2; + remaining = remaining / dim2; + size_t d1 = remaining % dim1; + size_t d0 = remaining / dim1; + + size_t dst_idx = d0 * (dim2 * dim1 * dim3) + d2 * (dim1 * dim3) + d1 * dim3 + d3; + dst[dst_idx] = src[idx]; + } +} + +// ============================================================================ +// 4D Transpose: [d0, d1, d2, d3] -> [d0, d1, d3, d2] +// Swaps last two axes (for K^T in attention) +// 
============================================================================ + +__global__ void transpose_0132_f32_kernel( + const float* __restrict__ src, + float* __restrict__ dst, + size_t dim0, + size_t dim1, + size_t dim2, + size_t dim3 +) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = dim0 * dim1 * dim2 * dim3; + + if (idx < total) { + // Compute source coordinates [d0, d1, d2, d3] + size_t d3 = idx % dim3; + size_t remaining = idx / dim3; + size_t d2 = remaining % dim2; + remaining = remaining / dim2; + size_t d1 = remaining % dim1; + size_t d0 = remaining / dim1; + + // Compute destination index [d0, d1, d3, d2] + size_t dst_idx = d0 * (dim1 * dim3 * dim2) + d1 * (dim3 * dim2) + d3 * dim2 + d2; + dst[dst_idx] = src[idx]; + } +} + +__global__ void transpose_0132_f16_kernel( + const __half* __restrict__ src, + __half* __restrict__ dst, + size_t dim0, + size_t dim1, + size_t dim2, + size_t dim3 +) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = dim0 * dim1 * dim2 * dim3; + + if (idx < total) { + size_t d3 = idx % dim3; + size_t remaining = idx / dim3; + size_t d2 = remaining % dim2; + remaining = remaining / dim2; + size_t d1 = remaining % dim1; + size_t d0 = remaining / dim1; + + size_t dst_idx = d0 * (dim1 * dim3 * dim2) + d1 * (dim3 * dim2) + d3 * dim2 + d2; + dst[dst_idx] = src[idx]; + } +} + +__global__ void transpose_0132_bf16_kernel( + const __nv_bfloat16* __restrict__ src, + __nv_bfloat16* __restrict__ dst, + size_t dim0, + size_t dim1, + size_t dim2, + size_t dim3 +) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = dim0 * dim1 * dim2 * dim3; + + if (idx < total) { + size_t d3 = idx % dim3; + size_t remaining = idx / dim3; + size_t d2 = remaining % dim2; + remaining = remaining / dim2; + size_t d1 = remaining % dim1; + size_t d0 = remaining / dim1; + + size_t dst_idx = d0 * (dim1 * dim3 * dim2) + d1 * (dim3 * dim2) + d3 * dim2 + d2; + dst[dst_idx] = src[idx]; + } +} + // Reshape with copy (ensures contiguous output) // Simply copies data - reshape is handled by changing shape metadata __global__ void copy_f32_kernel( @@ -398,6 +626,73 @@ __global__ void copy_i32_kernel( } } +// ============================================================================ +// Arange - generate sequence [start, start+step, start+2*step, ...] 
+// ============================================================================ + +__global__ void arange_f32_kernel(float* output, float start, float step, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = start + static_cast(idx) * step; + } +} + +__global__ void arange_i32_kernel(int* output, int start, int step, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = start + static_cast(idx) * step; + } +} + +__global__ void arange_i64_kernel(int64_t* output, int64_t start, int64_t step, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = start + static_cast(idx) * step; + } +} + +// ============================================================================ +// Scatter Add - indexed accumulation: output[indices[i]] += src[i] +// ============================================================================ + +__global__ void scatter_add_f32_kernel( + float* __restrict__ output, + const int64_t* __restrict__ indices, + const float* __restrict__ src, + size_t n +) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + atomicAdd(&output[indices[idx]], src[idx]); + } +} + +__global__ void scatter_add_f16_kernel( + __half* __restrict__ output, + const int64_t* __restrict__ indices, + const __half* __restrict__ src, + size_t n +) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // FP16 atomicAdd requires sm_70+ + atomicAdd(&output[indices[idx]], src[idx]); + } +} + +__global__ void scatter_add_bf16_kernel( + __nv_bfloat16* __restrict__ output, + const int64_t* __restrict__ indices, + const __nv_bfloat16* __restrict__ src, + size_t n +) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // BF16 atomicAdd requires sm_80+ + atomicAdd(&output[indices[idx]], src[idx]); + } +} + } // namespace nn } // namespace ops } // namespace pygpukit diff --git a/native/ops/nn/nn.cu b/native/ops/nn/nn.cu index 489ab67..fb9be55 100644 --- a/native/ops/nn/nn.cu +++ b/native/ops/nn/nn.cu @@ -817,6 +817,132 @@ void silu(const GPUArray& input, GPUArray& out) { sync_and_check("silu kernel failed"); } +// ============================================================================ +// Sigmoid Activation: 1 / (1 + exp(-x)) +// ============================================================================ + +static void sigmoid_dispatch(const GPUArray& input, GPUArray& result) { + size_t n = input.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + cudaStream_t stream = internal::get_capture_stream(); + + switch (input.dtype()) { + case DataType::Float32: + nn::sigmoid_f32_kernel<<>>( + static_cast(input.data()), + static_cast(result.data()), + n); + break; + case DataType::Float16: + nn::sigmoid_f16_kernel<<>>( + static_cast(input.data()), + static_cast<__half*>(result.data()), + n); + break; + case DataType::BFloat16: + nn::sigmoid_bf16_kernel<<>>( + static_cast(input.data()), + static_cast<__nv_bfloat16*>(result.data()), + n); + break; + default: + break; + } +} + +GPUArray sigmoid(const GPUArray& input) { + if (input.dtype() != DataType::Float32 && + input.dtype() != DataType::Float16 && input.dtype() != DataType::BFloat16) { + throw std::runtime_error("sigmoid only supports float types (f32, f16, bf16)"); + } + + GPUArray result(input.shape(), input.dtype()); + sigmoid_dispatch(input, result); + sync_and_check("sigmoid kernel failed"); + return result; +} + 
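Like the other activations in this file, sigmoid ships as a pair of overloads: an allocating form and an out-buffer form intended for CUDA Graph capture. A sketch of how the pair is meant to be used (assumes the `pygpukit::ops` API declared in ops.cuh; untested illustration):

```cpp
#include "ops.cuh"

// Illustration only: names and include path are assumed from this file's namespaces.
void sigmoid_overloads_example(const pygpukit::GPUArray& x) {
    using namespace pygpukit;

    // Allocating overload: convenient for one-off calls, allocates the result.
    GPUArray y = ops::sigmoid(x);

    // Out-buffer overload: the caller pre-allocates the output, so the call does
    // no allocation and can be recorded during CUDA Graph capture.
    GPUArray out(x.shape(), x.dtype());
    ops::sigmoid(x, out);  // throws on dtype/shape mismatch
    (void)y;
}
```

Pre-allocating the output is what allows the call to be recorded inside a capture, since device allocations are generally not permitted while a stream is being captured.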
+void sigmoid(const GPUArray& input, GPUArray& out) { + if (input.dtype() != DataType::Float32 && + input.dtype() != DataType::Float16 && input.dtype() != DataType::BFloat16) { + throw std::runtime_error("sigmoid only supports float types (f32, f16, bf16)"); + } + if (input.dtype() != out.dtype()) { + throw std::runtime_error("sigmoid: dtype mismatch between input and output"); + } + if (input.shape() != out.shape()) { + throw std::runtime_error("sigmoid: shape mismatch between input and output"); + } + + sigmoid_dispatch(input, out); + sync_and_check("sigmoid kernel failed"); +} + +// ============================================================================ +// Tanh Activation +// ============================================================================ + +static void tanh_dispatch(const GPUArray& input, GPUArray& result) { + size_t n = input.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + cudaStream_t stream = internal::get_capture_stream(); + + switch (input.dtype()) { + case DataType::Float32: + nn::tanh_f32_kernel<<>>( + static_cast(input.data()), + static_cast(result.data()), + n); + break; + case DataType::Float16: + nn::tanh_f16_kernel<<>>( + static_cast(input.data()), + static_cast<__half*>(result.data()), + n); + break; + case DataType::BFloat16: + nn::tanh_bf16_kernel<<>>( + static_cast(input.data()), + static_cast<__nv_bfloat16*>(result.data()), + n); + break; + default: + break; + } +} + +GPUArray tanh(const GPUArray& input) { + if (input.dtype() != DataType::Float32 && + input.dtype() != DataType::Float16 && input.dtype() != DataType::BFloat16) { + throw std::runtime_error("tanh only supports float types (f32, f16, bf16)"); + } + + GPUArray result(input.shape(), input.dtype()); + tanh_dispatch(input, result); + sync_and_check("tanh kernel failed"); + return result; +} + +void tanh(const GPUArray& input, GPUArray& out) { + if (input.dtype() != DataType::Float32 && + input.dtype() != DataType::Float16 && input.dtype() != DataType::BFloat16) { + throw std::runtime_error("tanh only supports float types (f32, f16, bf16)"); + } + if (input.dtype() != out.dtype()) { + throw std::runtime_error("tanh: dtype mismatch between input and output"); + } + if (input.shape() != out.shape()) { + throw std::runtime_error("tanh: shape mismatch between input and output"); + } + + tanh_dispatch(input, out); + sync_and_check("tanh kernel failed"); +} + // ============================================================================ // Scaled Dot-Product Attention (SDPA) with Causal Mask // ============================================================================ @@ -1436,6 +1562,292 @@ void transpose_3d_021(const GPUArray& input, GPUArray& out) { sync_and_check("transpose_3d_021 kernel failed"); } +// Internal helper for transpose_4d_0213 kernel dispatch +static void transpose_4d_0213_dispatch( + const GPUArray& input, + GPUArray& result, + size_t dim0, size_t dim1, size_t dim2, size_t dim3 +) { + size_t total = input.size(); + const int block_size = 256; + const int grid_size = (total + block_size - 1) / block_size; + + // Use capture stream if available + cudaStream_t stream = internal::get_capture_stream(); + + switch (input.dtype()) { + case DataType::Float32: + nn::transpose_0213_f32_kernel<<>>( + static_cast(input.data()), + static_cast(result.data()), + dim0, dim1, dim2, dim3); + break; + case DataType::Float16: + nn::transpose_0213_f16_kernel<<>>( + static_cast(input.data()), + static_cast<__half*>(result.data()), + dim0, dim1, 
dim2, dim3); + break; + case DataType::BFloat16: + nn::transpose_0213_bf16_kernel<<>>( + static_cast(input.data()), + static_cast<__nv_bfloat16*>(result.data()), + dim0, dim1, dim2, dim3); + break; + default: + throw std::runtime_error("transpose_4d_0213: unsupported dtype"); + } +} + +// Transpose 4D tensor: [d0, d1, d2, d3] -> [d0, d2, d1, d3] +GPUArray transpose_4d_0213(const GPUArray& input) { + if (input.dtype() != DataType::Float32 && input.dtype() != DataType::Float16 && + input.dtype() != DataType::BFloat16) { + throw std::runtime_error("transpose_4d_0213: only float32/float16/bfloat16 supported"); + } + if (input.ndim() != 4) { + throw std::runtime_error("transpose_4d_0213: expects 4D tensor"); + } + + size_t dim0 = input.shape()[0]; + size_t dim1 = input.shape()[1]; + size_t dim2 = input.shape()[2]; + size_t dim3 = input.shape()[3]; + + // Output shape: [dim0, dim2, dim1, dim3] + std::vector out_shape = {dim0, dim2, dim1, dim3}; + GPUArray result(out_shape, input.dtype()); + + transpose_4d_0213_dispatch(input, result, dim0, dim1, dim2, dim3); + sync_and_check("transpose_4d_0213 kernel failed"); + return result; +} + +// Transpose 4D tensor with output buffer (for CUDA Graph capture) +void transpose_4d_0213(const GPUArray& input, GPUArray& out) { + if (input.dtype() != DataType::Float32 && input.dtype() != DataType::Float16 && + input.dtype() != DataType::BFloat16) { + throw std::runtime_error("transpose_4d_0213: only float32/float16/bfloat16 supported"); + } + if (input.ndim() != 4) { + throw std::runtime_error("transpose_4d_0213: expects 4D tensor"); + } + if (out.ndim() != 4) { + throw std::runtime_error("transpose_4d_0213: output expects 4D tensor"); + } + if (input.dtype() != out.dtype()) { + throw std::runtime_error("transpose_4d_0213: dtype mismatch"); + } + + size_t dim0 = input.shape()[0]; + size_t dim1 = input.shape()[1]; + size_t dim2 = input.shape()[2]; + size_t dim3 = input.shape()[3]; + + // Verify output shape: [dim0, dim2, dim1, dim3] + if (out.shape()[0] != dim0 || out.shape()[1] != dim2 || + out.shape()[2] != dim1 || out.shape()[3] != dim3) { + throw std::runtime_error("transpose_4d_0213: output shape mismatch, expected [" + + std::to_string(dim0) + ", " + std::to_string(dim2) + ", " + + std::to_string(dim1) + ", " + std::to_string(dim3) + "]"); + } + + transpose_4d_0213_dispatch(input, out, dim0, dim1, dim2, dim3); + sync_and_check("transpose_4d_0213 kernel failed"); +} + +// ============================================================================ +// 3D Transpose: [d0, d1, d2] -> [d0, d2, d1] (swaps last two axes) +// ============================================================================ + +// Internal helper for transpose_3d_012 kernel dispatch +static void transpose_3d_012_dispatch( + const GPUArray& input, + GPUArray& result, + size_t dim0, size_t dim1, size_t dim2 +) { + size_t total = input.size(); + const int block_size = 256; + const int grid_size = (total + block_size - 1) / block_size; + + // Use capture stream if available + cudaStream_t stream = internal::get_capture_stream(); + + switch (input.dtype()) { + case DataType::Float32: + nn::transpose_012_f32_kernel<<>>( + static_cast(input.data()), + static_cast(result.data()), + dim0, dim1, dim2); + break; + case DataType::Float16: + nn::transpose_012_f16_kernel<<>>( + static_cast(input.data()), + static_cast<__half*>(result.data()), + dim0, dim1, dim2); + break; + case DataType::BFloat16: + nn::transpose_012_bf16_kernel<<>>( + static_cast(input.data()), + 
static_cast<__nv_bfloat16*>(result.data()), + dim0, dim1, dim2); + break; + default: + throw std::runtime_error("transpose_3d_012: unsupported dtype"); + } +} + +// Transpose 3D tensor: [d0, d1, d2] -> [d0, d2, d1] +GPUArray transpose_3d_012(const GPUArray& input) { + if (input.dtype() != DataType::Float32 && input.dtype() != DataType::Float16 && + input.dtype() != DataType::BFloat16) { + throw std::runtime_error("transpose_3d_012: only float32/float16/bfloat16 supported"); + } + if (input.ndim() != 3) { + throw std::runtime_error("transpose_3d_012: expects 3D tensor"); + } + + size_t dim0 = input.shape()[0]; + size_t dim1 = input.shape()[1]; + size_t dim2 = input.shape()[2]; + + // Output shape: [dim0, dim2, dim1] + std::vector out_shape = {dim0, dim2, dim1}; + GPUArray result(out_shape, input.dtype()); + + transpose_3d_012_dispatch(input, result, dim0, dim1, dim2); + sync_and_check("transpose_3d_012 kernel failed"); + return result; +} + +// Transpose 3D tensor with output buffer (for CUDA Graph capture) +void transpose_3d_012(const GPUArray& input, GPUArray& out) { + if (input.dtype() != DataType::Float32 && input.dtype() != DataType::Float16 && + input.dtype() != DataType::BFloat16) { + throw std::runtime_error("transpose_3d_012: only float32/float16/bfloat16 supported"); + } + if (input.ndim() != 3) { + throw std::runtime_error("transpose_3d_012: expects 3D tensor"); + } + if (out.ndim() != 3) { + throw std::runtime_error("transpose_3d_012: output expects 3D tensor"); + } + if (input.dtype() != out.dtype()) { + throw std::runtime_error("transpose_3d_012: dtype mismatch"); + } + + size_t dim0 = input.shape()[0]; + size_t dim1 = input.shape()[1]; + size_t dim2 = input.shape()[2]; + + // Verify output shape: [dim0, dim2, dim1] + if (out.shape()[0] != dim0 || out.shape()[1] != dim2 || out.shape()[2] != dim1) { + throw std::runtime_error("transpose_3d_012: output shape mismatch, expected [" + + std::to_string(dim0) + ", " + std::to_string(dim2) + ", " + std::to_string(dim1) + "]"); + } + + transpose_3d_012_dispatch(input, out, dim0, dim1, dim2); + sync_and_check("transpose_3d_012 kernel failed"); +} + +// ============================================================================ +// 4D Transpose: [d0, d1, d2, d3] -> [d0, d1, d3, d2] (swaps last two axes) +// ============================================================================ + +// Internal helper for transpose_4d_0132 kernel dispatch +static void transpose_4d_0132_dispatch( + const GPUArray& input, + GPUArray& result, + size_t dim0, size_t dim1, size_t dim2, size_t dim3 +) { + size_t total = input.size(); + const int block_size = 256; + const int grid_size = (total + block_size - 1) / block_size; + + // Use capture stream if available + cudaStream_t stream = internal::get_capture_stream(); + + switch (input.dtype()) { + case DataType::Float32: + nn::transpose_0132_f32_kernel<<>>( + static_cast(input.data()), + static_cast(result.data()), + dim0, dim1, dim2, dim3); + break; + case DataType::Float16: + nn::transpose_0132_f16_kernel<<>>( + static_cast(input.data()), + static_cast<__half*>(result.data()), + dim0, dim1, dim2, dim3); + break; + case DataType::BFloat16: + nn::transpose_0132_bf16_kernel<<>>( + static_cast(input.data()), + static_cast<__nv_bfloat16*>(result.data()), + dim0, dim1, dim2, dim3); + break; + default: + throw std::runtime_error("transpose_4d_0132: unsupported dtype"); + } +} + +// Transpose 4D tensor: [d0, d1, d2, d3] -> [d0, d1, d3, d2] +GPUArray transpose_4d_0132(const GPUArray& input) { + if (input.dtype() 
!= DataType::Float32 && input.dtype() != DataType::Float16 && + input.dtype() != DataType::BFloat16) { + throw std::runtime_error("transpose_4d_0132: only float32/float16/bfloat16 supported"); + } + if (input.ndim() != 4) { + throw std::runtime_error("transpose_4d_0132: expects 4D tensor"); + } + + size_t dim0 = input.shape()[0]; + size_t dim1 = input.shape()[1]; + size_t dim2 = input.shape()[2]; + size_t dim3 = input.shape()[3]; + + // Output shape: [dim0, dim1, dim3, dim2] + std::vector out_shape = {dim0, dim1, dim3, dim2}; + GPUArray result(out_shape, input.dtype()); + + transpose_4d_0132_dispatch(input, result, dim0, dim1, dim2, dim3); + sync_and_check("transpose_4d_0132 kernel failed"); + return result; +} + +// Transpose 4D tensor with output buffer (for CUDA Graph capture) +void transpose_4d_0132(const GPUArray& input, GPUArray& out) { + if (input.dtype() != DataType::Float32 && input.dtype() != DataType::Float16 && + input.dtype() != DataType::BFloat16) { + throw std::runtime_error("transpose_4d_0132: only float32/float16/bfloat16 supported"); + } + if (input.ndim() != 4) { + throw std::runtime_error("transpose_4d_0132: expects 4D tensor"); + } + if (out.ndim() != 4) { + throw std::runtime_error("transpose_4d_0132: output expects 4D tensor"); + } + if (input.dtype() != out.dtype()) { + throw std::runtime_error("transpose_4d_0132: dtype mismatch"); + } + + size_t dim0 = input.shape()[0]; + size_t dim1 = input.shape()[1]; + size_t dim2 = input.shape()[2]; + size_t dim3 = input.shape()[3]; + + // Verify output shape: [dim0, dim1, dim3, dim2] + if (out.shape()[0] != dim0 || out.shape()[1] != dim1 || + out.shape()[2] != dim3 || out.shape()[3] != dim2) { + throw std::runtime_error("transpose_4d_0132: output shape mismatch, expected [" + + std::to_string(dim0) + ", " + std::to_string(dim1) + ", " + + std::to_string(dim3) + ", " + std::to_string(dim2) + "]"); + } + + transpose_4d_0132_dispatch(input, out, dim0, dim1, dim2, dim3); + sync_and_check("transpose_4d_0132 kernel failed"); +} + // Internal helper for reshape_copy kernel dispatch static void reshape_copy_dispatch( const GPUArray& input, diff --git a/native/ops/ops.cuh b/native/ops/ops.cuh index 3c12a11..bf58f9e 100644 --- a/native/ops/ops.cuh +++ b/native/ops/ops.cuh @@ -34,6 +34,14 @@ GPUArray sub(const GPUArray& a, const GPUArray& b); void div(const GPUArray& a, const GPUArray& b, GPUArray& c); GPUArray div(const GPUArray& a, const GPUArray& b); +// Clamp: c = clamp(a, min_val, max_val) +void clamp(const GPUArray& a, GPUArray& c, float min_val, float max_val); +GPUArray clamp(const GPUArray& a, float min_val, float max_val); + +// Where: c = cond ? 
a : b (conditional select) +void where(const GPUArray& cond, const GPUArray& a, const GPUArray& b, GPUArray& c); +GPUArray where(const GPUArray& cond, const GPUArray& a, const GPUArray& b); + // ============================================================================ // Unary Operations // ============================================================================ @@ -50,6 +58,30 @@ GPUArray log(const GPUArray& a); void relu(const GPUArray& a, GPUArray& c); GPUArray relu(const GPUArray& a); +// Sin: c = sin(a) +void sin(const GPUArray& a, GPUArray& c); +GPUArray sin(const GPUArray& a); + +// Cos: c = cos(a) +void cos(const GPUArray& a, GPUArray& c); +GPUArray cos(const GPUArray& a); + +// Sqrt: c = sqrt(a) +void sqrt(const GPUArray& a, GPUArray& c); +GPUArray sqrt(const GPUArray& a); + +// Rsqrt: c = 1/sqrt(a) +void rsqrt(const GPUArray& a, GPUArray& c); +GPUArray rsqrt(const GPUArray& a); + +// Abs: c = |a| +void abs(const GPUArray& a, GPUArray& c); +GPUArray abs(const GPUArray& a); + +// Neg: c = -a +void neg(const GPUArray& a, GPUArray& c); +GPUArray neg(const GPUArray& a); + // ============================================================================ // Reduction Operations // ============================================================================ @@ -63,6 +95,16 @@ GPUArray mean(const GPUArray& a); // Max: scalar max of all elements GPUArray max(const GPUArray& a); +// Min: scalar min of all elements +GPUArray min(const GPUArray& a); + +// Argmax: index of maximum element +GPUArray argmax(const GPUArray& a); + +// Sum with axis: sum along specified axis (0 or 1) +// input: [M, N], axis=0 -> output: [N], axis=1 -> output: [M] +GPUArray sum_axis(const GPUArray& a, int axis); + // ============================================================================ // Matrix Multiplication // ============================================================================ @@ -116,6 +158,14 @@ GPUArray silu(const GPUArray& input); // SiLU with output buffer (for CUDA Graph capture) void silu(const GPUArray& input, GPUArray& out); +// Sigmoid activation: y = 1 / (1 + exp(-x)) +GPUArray sigmoid(const GPUArray& input); +void sigmoid(const GPUArray& input, GPUArray& out); + +// Tanh activation +GPUArray tanh(const GPUArray& input); +void tanh(const GPUArray& input, GPUArray& out); + // RoPE (Rotary Position Embedding) - In-place // q: [seq_len, n_heads_q, head_dim] // k: [seq_len, n_heads_k, head_dim] @@ -177,6 +227,13 @@ void sdpa_causal_fixed_cache_ptr(const GPUArray& Q, const GPUArray& K, const GPU // output: [batch, out_features] GPUArray linear_bias_gelu(const GPUArray& input, const GPUArray& weight, const GPUArray& bias); +// Strided Batched GEMM: C[b] = A[b] @ B[b] for b in [0, batch_count) +// A: [batch, M, K], B: [batch, K, N], C: [batch, M, N] (row-major) +// Uses CUTLASS TensorCore for high performance +void batched_matmul_fp32(const GPUArray& A, const GPUArray& B, GPUArray& C, + int M, int N, int K, int batch_count, + int64_t strideA, int64_t strideB, int64_t strideC); + // ============================================================================ // Tensor Manipulation Operations // ============================================================================ @@ -194,6 +251,24 @@ GPUArray transpose_3d_021(const GPUArray& input); // Transpose 3D tensor with output buffer (for CUDA Graph capture) void transpose_3d_021(const GPUArray& input, GPUArray& out); +// Transpose 4D tensor: [d0, d1, d2, d3] -> [d0, d2, d1, d3] +// Swaps axes 1 and 2 (common in attention: batch, seq, heads, 
dim -> batch, heads, seq, dim) +GPUArray transpose_4d_0213(const GPUArray& input); +// Transpose 4D tensor with output buffer (for CUDA Graph capture) +void transpose_4d_0213(const GPUArray& input, GPUArray& out); + +// Transpose 3D tensor: [d0, d1, d2] -> [d0, d2, d1] +// Swaps last two axes (common in attention operations) +GPUArray transpose_3d_012(const GPUArray& input); +// Transpose 3D tensor with output buffer (for CUDA Graph capture) +void transpose_3d_012(const GPUArray& input, GPUArray& out); + +// Transpose 4D tensor: [d0, d1, d2, d3] -> [d0, d1, d3, d2] +// Swaps last two axes (for K^T in attention) +GPUArray transpose_4d_0132(const GPUArray& input); +// Transpose 4D tensor with output buffer (for CUDA Graph capture) +void transpose_4d_0132(const GPUArray& input, GPUArray& out); + // Reshape with copy (creates contiguous tensor with new shape) GPUArray reshape_copy(const GPUArray& input, const std::vector& new_shape); // Reshape with copy into output buffer (for CUDA Graph capture) diff --git a/native/ops/reduction/reduction.cu b/native/ops/reduction/reduction.cu index f1eb7f7..c821172 100644 --- a/native/ops/reduction/reduction.cu +++ b/native/ops/reduction/reduction.cu @@ -193,5 +193,182 @@ GPUArray max(const GPUArray& a) { return result; } +// ============================================================================ +// Min +// ============================================================================ + +GPUArray min(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("min only supports float types (f32, f16, bf16)"); + } + + GPUArray result({1}, a.dtype()); + size_t n = a.size(); + + const int block_size = 256; + const int max_blocks = 256; + const int grid_size = std::min((int)((n + block_size - 1) / block_size), max_blocks); + + switch (a.dtype()) { + case DataType::Float32: + init_min_f32_kernel<<<1, 1>>>(static_cast(result.data())); + reduce_min_f32_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + break; + case DataType::Float16: + init_min_f16_kernel<<<1, 1>>>(static_cast<__half*>(result.data())); + reduce_min_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(result.data()), + n); + break; + case DataType::BFloat16: + init_min_bf16_kernel<<<1, 1>>>(static_cast<__nv_bfloat16*>(result.data())); + reduce_min_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(result.data()), + n); + break; + default: + break; + } + + sync_and_check("min kernel failed"); + return result; +} + +// ============================================================================ +// Argmax +// ============================================================================ + +GPUArray argmax(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("argmax only supports float types (f32, f16, bf16)"); + } + + GPUArray result({1}, DataType::Int64); + size_t n = a.size(); + + // Single block reduction for simplicity - argmax needs coordination + const int block_size = 256; + const int grid_size = 1; // Single block for global argmax + + switch (a.dtype()) { + case DataType::Float32: + argmax_f32_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + break; + case DataType::Float16: + argmax_f16_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + break; + case DataType::BFloat16: + 
argmax_bf16_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + break; + default: + break; + } + + sync_and_check("argmax kernel failed"); + return result; +} + +// ============================================================================ +// Sum with axis +// ============================================================================ + +GPUArray sum_axis(const GPUArray& a, int axis) { + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("sum_axis only supports float types (f32, f16, bf16)"); + } + if (a.ndim() != 2) { + throw std::runtime_error("sum_axis only supports 2D tensors"); + } + if (axis != 0 && axis != 1) { + throw std::runtime_error("sum_axis: axis must be 0 or 1"); + } + + int M = a.shape()[0]; + int N = a.shape()[1]; + + std::vector out_shape; + if (axis == 0) { + out_shape = {static_cast(N)}; + } else { + out_shape = {static_cast(M)}; + } + + GPUArray result(out_shape, a.dtype()); + + const int block_size = 256; + + if (axis == 0) { + // Sum along rows -> output [N] + const int grid_size = (N + block_size - 1) / block_size; + switch (a.dtype()) { + case DataType::Float32: + sum_axis0_f32_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + M, N); + break; + case DataType::Float16: + sum_axis0_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(result.data()), + M, N); + break; + case DataType::BFloat16: + sum_axis0_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(result.data()), + M, N); + break; + default: + break; + } + } else { + // Sum along columns -> output [M] + const int grid_size = (M + block_size - 1) / block_size; + switch (a.dtype()) { + case DataType::Float32: + sum_axis1_f32_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + M, N); + break; + case DataType::Float16: + sum_axis1_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(result.data()), + M, N); + break; + case DataType::BFloat16: + sum_axis1_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(result.data()), + M, N); + break; + default: + break; + } + } + + sync_and_check("sum_axis kernel failed"); + return result; +} + } // namespace ops } // namespace pygpukit diff --git a/native/ops/reduction/reduction_kernels.cuh b/native/ops/reduction/reduction_kernels.cuh index 7fa5099..7c5d384 100644 --- a/native/ops/reduction/reduction_kernels.cuh +++ b/native/ops/reduction/reduction_kernels.cuh @@ -324,6 +324,331 @@ __global__ void reduce_max_bf16_kernel(const __nv_bfloat16* __restrict__ input, } } +// ============================================================================ +// Min reduction kernels +// ============================================================================ + +__device__ __forceinline__ float warp_reduce_min(float val) { + for (int offset = 16; offset > 0; offset /= 2) { + val = fminf(val, __shfl_down_sync(0xffffffff, val, offset)); + } + return val; +} + +__global__ void reduce_min_f32_kernel(const float* __restrict__ input, float* __restrict__ output, size_t n) { + __shared__ float shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float min_val = INFINITY; + for (size_t i = idx; i < n; i += stride) { + min_val = fminf(min_val, input[i]); + } + + min_val = warp_reduce_min(min_val); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + 
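+        // Lane 0 of each warp publishes its partial minimum to shared memory
+        // for the second, block-level reduction below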
shared[warp_id] = min_val; + } + __syncthreads(); + + if (warp_id == 0) { + min_val = (tid < (blockDim.x + 31) / 32) ? shared[lane] : INFINITY; + min_val = warp_reduce_min(min_val); + if (lane == 0) { + int* addr = (int*)output; + int expected = *addr; + while (min_val < __int_as_float(expected)) { + int old = atomicCAS(addr, expected, __float_as_int(min_val)); + if (old == expected) break; + expected = old; + } + } + } +} + +__global__ void reduce_min_f16_kernel(const __half* __restrict__ input, __half* __restrict__ output, size_t n) { + __shared__ float shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float min_val = INFINITY; + for (size_t i = idx; i < n; i += stride) { + min_val = fminf(min_val, __half2float(input[i])); + } + + min_val = warp_reduce_min(min_val); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = min_val; + } + __syncthreads(); + + if (warp_id == 0) { + min_val = (tid < (blockDim.x + 31) / 32) ? shared[lane] : INFINITY; + min_val = warp_reduce_min(min_val); + if (lane == 0) { + float old_val = __half2float(*output); + if (min_val < old_val) { + *output = __float2half(min_val); + } + } + } +} + +__global__ void reduce_min_bf16_kernel(const __nv_bfloat16* __restrict__ input, __nv_bfloat16* __restrict__ output, size_t n) { + __shared__ float shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float min_val = INFINITY; + for (size_t i = idx; i < n; i += stride) { + min_val = fminf(min_val, bf16_to_float(input[i])); + } + + min_val = warp_reduce_min(min_val); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = min_val; + } + __syncthreads(); + + if (warp_id == 0) { + min_val = (tid < (blockDim.x + 31) / 32) ? 
shared[lane] : INFINITY; + min_val = warp_reduce_min(min_val); + if (lane == 0) { + float old_val = bf16_to_float(*output); + if (min_val < old_val) { + *output = float_to_bf16(min_val); + } + } + } +} + +__global__ void init_min_f32_kernel(float* output) { *output = INFINITY; } +__global__ void init_min_f16_kernel(__half* output) { *output = __float2half(INFINITY); } +__global__ void init_min_bf16_kernel(__nv_bfloat16* output) { *output = float_to_bf16(INFINITY); } + +// ============================================================================ +// Sum with axis kernels - reduce along specified axis +// For 2D tensor [M, N]: axis=0 reduces to [N], axis=1 reduces to [M] +// ============================================================================ + +// Sum along axis 0: [M, N] -> [N] +__global__ void sum_axis0_f32_kernel(const float* __restrict__ input, float* __restrict__ output, + int M, int N) { + int n = blockIdx.x * blockDim.x + threadIdx.x; + if (n >= N) return; + + float sum = 0.0f; + for (int m = 0; m < M; ++m) { + sum += input[m * N + n]; + } + output[n] = sum; +} + +__global__ void sum_axis0_f16_kernel(const __half* __restrict__ input, __half* __restrict__ output, + int M, int N) { + int n = blockIdx.x * blockDim.x + threadIdx.x; + if (n >= N) return; + + float sum = 0.0f; + for (int m = 0; m < M; ++m) { + sum += __half2float(input[m * N + n]); + } + output[n] = __float2half(sum); +} + +__global__ void sum_axis0_bf16_kernel(const __nv_bfloat16* __restrict__ input, __nv_bfloat16* __restrict__ output, + int M, int N) { + int n = blockIdx.x * blockDim.x + threadIdx.x; + if (n >= N) return; + + float sum = 0.0f; + for (int m = 0; m < M; ++m) { + sum += bf16_to_float(input[m * N + n]); + } + output[n] = float_to_bf16(sum); +} + +// Sum along axis 1: [M, N] -> [M] +__global__ void sum_axis1_f32_kernel(const float* __restrict__ input, float* __restrict__ output, + int M, int N) { + int m = blockIdx.x * blockDim.x + threadIdx.x; + if (m >= M) return; + + float sum = 0.0f; + for (int n = 0; n < N; ++n) { + sum += input[m * N + n]; + } + output[m] = sum; +} + +__global__ void sum_axis1_f16_kernel(const __half* __restrict__ input, __half* __restrict__ output, + int M, int N) { + int m = blockIdx.x * blockDim.x + threadIdx.x; + if (m >= M) return; + + float sum = 0.0f; + for (int n = 0; n < N; ++n) { + sum += __half2float(input[m * N + n]); + } + output[m] = __float2half(sum); +} + +__global__ void sum_axis1_bf16_kernel(const __nv_bfloat16* __restrict__ input, __nv_bfloat16* __restrict__ output, + int M, int N) { + int m = blockIdx.x * blockDim.x + threadIdx.x; + if (m >= M) return; + + float sum = 0.0f; + for (int n = 0; n < N; ++n) { + sum += bf16_to_float(input[m * N + n]); + } + output[m] = float_to_bf16(sum); +} + +// ============================================================================ +// Argmax reduction kernels - find index of maximum value +// ============================================================================ + +// Warp-level argmax primitive +__device__ __forceinline__ void warp_reduce_argmax(float& val, int& idx) { + for (int offset = 16; offset > 0; offset /= 2) { + float other_val = __shfl_down_sync(0xffffffff, val, offset); + int other_idx = __shfl_down_sync(0xffffffff, idx, offset); + if (other_val > val) { + val = other_val; + idx = other_idx; + } + } +} + +__global__ void argmax_f32_kernel(const float* __restrict__ input, int64_t* __restrict__ output, size_t n) { + __shared__ float shared_val[32]; + __shared__ int shared_idx[32]; + + const size_t tid = 
threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float max_val = -INFINITY; + int max_idx = 0; + for (size_t i = idx; i < n; i += stride) { + if (input[i] > max_val) { + max_val = input[i]; + max_idx = static_cast<int>(i); + } + } + + warp_reduce_argmax(max_val, max_idx); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared_val[warp_id] = max_val; + shared_idx[warp_id] = max_idx; + } + __syncthreads(); + + if (warp_id == 0) { + max_val = (tid < (blockDim.x + 31) / 32) ? shared_val[lane] : -INFINITY; + max_idx = (tid < (blockDim.x + 31) / 32) ? shared_idx[lane] : 0; + warp_reduce_argmax(max_val, max_idx); + if (lane == 0) { + *output = static_cast<int64_t>(max_idx); + } + } +} + +__global__ void argmax_f16_kernel(const __half* __restrict__ input, int64_t* __restrict__ output, size_t n) { + __shared__ float shared_val[32]; + __shared__ int shared_idx[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float max_val = -INFINITY; + int max_idx = 0; + for (size_t i = idx; i < n; i += stride) { + float v = __half2float(input[i]); + if (v > max_val) { + max_val = v; + max_idx = static_cast<int>(i); + } + } + + warp_reduce_argmax(max_val, max_idx); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared_val[warp_id] = max_val; + shared_idx[warp_id] = max_idx; + } + __syncthreads(); + + if (warp_id == 0) { + max_val = (tid < (blockDim.x + 31) / 32) ? shared_val[lane] : -INFINITY; + max_idx = (tid < (blockDim.x + 31) / 32) ? shared_idx[lane] : 0; + warp_reduce_argmax(max_val, max_idx); + if (lane == 0) { + *output = static_cast<int64_t>(max_idx); + } + } +} + +__global__ void argmax_bf16_kernel(const __nv_bfloat16* __restrict__ input, int64_t* __restrict__ output, size_t n) { + __shared__ float shared_val[32]; + __shared__ int shared_idx[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float max_val = -INFINITY; + int max_idx = 0; + for (size_t i = idx; i < n; i += stride) { + float v = bf16_to_float(input[i]); + if (v > max_val) { + max_val = v; + max_idx = static_cast<int>(i); + } + } + + warp_reduce_argmax(max_val, max_idx); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared_val[warp_id] = max_val; + shared_idx[warp_id] = max_idx; + } + __syncthreads(); + + if (warp_id == 0) { + max_val = (tid < (blockDim.x + 31) / 32) ? shared_val[lane] : -INFINITY; + max_idx = (tid < (blockDim.x + 31) / 32) ?
shared_idx[lane] : 0; + warp_reduce_argmax(max_val, max_idx); + if (lane == 0) { + *output = static_cast<int64_t>(max_idx); + } + } +} + // ============================================================================ // Output initialization kernels // ============================================================================ diff --git a/native/ops/unary/unary.cu b/native/ops/unary/unary.cu index 9d6e50f..d56477a 100644 --- a/native/ops/unary/unary.cu +++ b/native/ops/unary/unary.cu @@ -172,5 +172,299 @@ GPUArray relu(const GPUArray& a) { return c; } +// ============================================================================ +// Sin +// ============================================================================ + +void sin(const GPUArray& a, GPUArray& c) { + validate_same_shape(a, c, "sin"); + validate_same_dtype(a, c, "sin"); + + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("sin only supports float types"); + } + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + sin_f32_kernel<<<grid_size, block_size>>>( + static_cast<const float*>(a.data()), + static_cast<float*>(c.data()), n); + break; + case DataType::Float16: + sin_f16_kernel<<<grid_size, block_size>>>( + static_cast<const __half*>(a.data()), + static_cast<__half*>(c.data()), n); + break; + case DataType::BFloat16: + sin_bf16_kernel<<<grid_size, block_size>>>( + static_cast<const __nv_bfloat16*>(a.data()), + static_cast<__nv_bfloat16*>(c.data()), n); + break; + default: + break; + } + sync_and_check("sin kernel failed"); +} + +GPUArray sin(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("sin only supports float types"); + } + GPUArray c(a.shape(), a.dtype()); + sin(a, c); + return c; +} + +// ============================================================================ +// Cos +// ============================================================================ + +void cos(const GPUArray& a, GPUArray& c) { + validate_same_shape(a, c, "cos"); + validate_same_dtype(a, c, "cos"); + + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("cos only supports float types"); + } + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + cos_f32_kernel<<<grid_size, block_size>>>( + static_cast<const float*>(a.data()), + static_cast<float*>(c.data()), n); + break; + case DataType::Float16: + cos_f16_kernel<<<grid_size, block_size>>>( + static_cast<const __half*>(a.data()), + static_cast<__half*>(c.data()), n); + break; + case DataType::BFloat16: + cos_bf16_kernel<<<grid_size, block_size>>>( + static_cast<const __nv_bfloat16*>(a.data()), + static_cast<__nv_bfloat16*>(c.data()), n); + break; + default: + break; + } + sync_and_check("cos kernel failed"); +} + +GPUArray cos(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("cos only supports float types"); + } + GPUArray c(a.shape(), a.dtype()); + cos(a, c); + return c; +} + +// ============================================================================ +// Sqrt +// ============================================================================ + +void sqrt(const GPUArray& a, GPUArray& c) { + validate_same_shape(a, c, "sqrt"); + validate_same_dtype(a, c, "sqrt"); + + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() !=
DataType::BFloat16) { + throw std::runtime_error("sqrt only supports float types"); + } + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + sqrt_f32_kernel<<<grid_size, block_size>>>( + static_cast<const float*>(a.data()), + static_cast<float*>(c.data()), n); + break; + case DataType::Float16: + sqrt_f16_kernel<<<grid_size, block_size>>>( + static_cast<const __half*>(a.data()), + static_cast<__half*>(c.data()), n); + break; + case DataType::BFloat16: + sqrt_bf16_kernel<<<grid_size, block_size>>>( + static_cast<const __nv_bfloat16*>(a.data()), + static_cast<__nv_bfloat16*>(c.data()), n); + break; + default: + break; + } + sync_and_check("sqrt kernel failed"); +} + +GPUArray sqrt(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("sqrt only supports float types"); + } + GPUArray c(a.shape(), a.dtype()); + sqrt(a, c); + return c; +} + +// ============================================================================ +// Rsqrt (1/sqrt(x)) +// ============================================================================ + +void rsqrt(const GPUArray& a, GPUArray& c) { + validate_same_shape(a, c, "rsqrt"); + validate_same_dtype(a, c, "rsqrt"); + + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("rsqrt only supports float types"); + } + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + rsqrt_f32_kernel<<<grid_size, block_size>>>( + static_cast<const float*>(a.data()), + static_cast<float*>(c.data()), n); + break; + case DataType::Float16: + rsqrt_f16_kernel<<<grid_size, block_size>>>( + static_cast<const __half*>(a.data()), + static_cast<__half*>(c.data()), n); + break; + case DataType::BFloat16: + rsqrt_bf16_kernel<<<grid_size, block_size>>>( + static_cast<const __nv_bfloat16*>(a.data()), + static_cast<__nv_bfloat16*>(c.data()), n); + break; + default: + break; + } + sync_and_check("rsqrt kernel failed"); +} + +GPUArray rsqrt(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("rsqrt only supports float types"); + } + GPUArray c(a.shape(), a.dtype()); + rsqrt(a, c); + return c; +} + +// ============================================================================ +// Abs +// ============================================================================ + +void abs(const GPUArray& a, GPUArray& c) { + validate_same_shape(a, c, "abs"); + validate_same_dtype(a, c, "abs"); + + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("abs only supports float types"); + } + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + abs_f32_kernel<<<grid_size, block_size>>>( + static_cast<const float*>(a.data()), + static_cast<float*>(c.data()), n); + break; + case DataType::Float16: + abs_f16_kernel<<<grid_size, block_size>>>( + static_cast<const __half*>(a.data()), + static_cast<__half*>(c.data()), n); + break; + case DataType::BFloat16: + abs_bf16_kernel<<<grid_size, block_size>>>( + static_cast<const __nv_bfloat16*>(a.data()), + static_cast<__nv_bfloat16*>(c.data()), n); + break; + default: + break; + } + sync_and_check("abs kernel failed"); +} + +GPUArray abs(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("abs only supports float types"); + } + GPUArray c(a.shape(),
a.dtype()); + abs(a, c); + return c; +} + +// ============================================================================ +// Neg (-x) +// ============================================================================ + +void neg(const GPUArray& a, GPUArray& c) { + validate_same_shape(a, c, "neg"); + validate_same_dtype(a, c, "neg"); + + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("neg only supports float types"); + } + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + neg_f32_kernel<<<grid_size, block_size>>>( + static_cast<const float*>(a.data()), + static_cast<float*>(c.data()), n); + break; + case DataType::Float16: + neg_f16_kernel<<<grid_size, block_size>>>( + static_cast<const __half*>(a.data()), + static_cast<__half*>(c.data()), n); + break; + case DataType::BFloat16: + neg_bf16_kernel<<<grid_size, block_size>>>( + static_cast<const __nv_bfloat16*>(a.data()), + static_cast<__nv_bfloat16*>(c.data()), n); + break; + default: + break; + } + sync_and_check("neg kernel failed"); +} + +GPUArray neg(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("neg only supports float types"); + } + GPUArray c(a.shape(), a.dtype()); + neg(a, c); + return c; +} + } // namespace ops } // namespace pygpukit diff --git a/native/ops/unary/unary_kernels.cuh b/native/ops/unary/unary_kernels.cuh index a434e4c..7776bf8 100644 --- a/native/ops/unary/unary_kernels.cuh +++ b/native/ops/unary/unary_kernels.cuh @@ -111,6 +111,156 @@ __global__ void relu_bf16_kernel(const __nv_bfloat16* a, __nv_bfloat16* c, size_ } } +// ============================================================================ +// Sin kernels +// ============================================================================ + +__global__ void sin_f32_kernel(const float* a, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = sinf(a[idx]); + } +} + +__global__ void sin_f16_kernel(const __half* a, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __float2half(sinf(__half2float(a[idx]))); + } +} + +__global__ void sin_bf16_kernel(const __nv_bfloat16* a, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = float_to_bf16(sinf(bf16_to_float(a[idx]))); + } +} + +// ============================================================================ +// Cos kernels +// ============================================================================ + +__global__ void cos_f32_kernel(const float* a, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = cosf(a[idx]); + } +} + +__global__ void cos_f16_kernel(const __half* a, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __float2half(cosf(__half2float(a[idx]))); + } +} + +__global__ void cos_bf16_kernel(const __nv_bfloat16* a, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = float_to_bf16(cosf(bf16_to_float(a[idx]))); + } +} + +// ============================================================================ +// Sqrt kernels +// ============================================================================ + +__global__ void sqrt_f32_kernel(const float* a, float* c, size_t n) { + size_t idx = blockIdx.x *
blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = sqrtf(a[idx]); + } +} + +__global__ void sqrt_f16_kernel(const __half* a, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __float2half(sqrtf(__half2float(a[idx]))); + } +} + +__global__ void sqrt_bf16_kernel(const __nv_bfloat16* a, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = float_to_bf16(sqrtf(bf16_to_float(a[idx]))); + } +} + +// ============================================================================ +// Rsqrt kernels (reciprocal sqrt: 1/sqrt(x)) +// ============================================================================ + +__global__ void rsqrt_f32_kernel(const float* a, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = rsqrtf(a[idx]); + } +} + +__global__ void rsqrt_f16_kernel(const __half* a, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __float2half(rsqrtf(__half2float(a[idx]))); + } +} + +__global__ void rsqrt_bf16_kernel(const __nv_bfloat16* a, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = float_to_bf16(rsqrtf(bf16_to_float(a[idx]))); + } +} + +// ============================================================================ +// Abs kernels +// ============================================================================ + +__global__ void abs_f32_kernel(const float* a, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = fabsf(a[idx]); + } +} + +__global__ void abs_f16_kernel(const __half* a, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __float2half(fabsf(__half2float(a[idx]))); + } +} + +__global__ void abs_bf16_kernel(const __nv_bfloat16* a, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = float_to_bf16(fabsf(bf16_to_float(a[idx]))); + } +} + +// ============================================================================ +// Neg kernels (negate: -x) +// ============================================================================ + +__global__ void neg_f32_kernel(const float* a, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = -a[idx]; + } +} + +__global__ void neg_f16_kernel(const __half* a, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __hneg(a[idx]); + } +} + +__global__ void neg_bf16_kernel(const __nv_bfloat16* a, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __hneg(a[idx]); + } +} + } // namespace unary } // namespace ops } // namespace pygpukit diff --git a/pyproject.toml b/pyproject.toml index 3ad3e92..8ca2249 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build" [project] name = "PyGPUkit" -version = "0.2.14" +version = "0.2.15" description = "A lightweight GPU runtime for Python with Rust-powered scheduler, NVRTC JIT compilation, and NumPy-like API" readme = "README.md" license = "MIT" @@ -59,10 +59,11 @@ build.targets = [] sdist.include = ["native/*", "rust/*"] sdist.exclude = ["native/build/*", "rust/target/*"] -[tool.scikit-build.cmake.define] -# PyGPUkit requires SM >= 80 (Ampere and newer) for cp.async 
support -# Default: SM80-90 (CUDA 12.x), SM100+ requires CUDA 13.x and env override -CMAKE_CUDA_ARCHITECTURES = "80;86;89;90" +# [tool.scikit-build.cmake.define] +# SM architectures are controlled via CMAKE_CUDA_ARCHITECTURES: +# - CMakeLists.txt default: "80;86;89;90" (CUDA 12.x compatible) +# - Override via CMAKE_ARGS env var: CMAKE_ARGS=-DCMAKE_CUDA_ARCHITECTURES=120 +# - SM100+ (Blackwell) requires CUDA 12.8+ or 13.x [tool.cibuildwheel] # Skip PyPy, 32-bit builds, and musllinux diff --git a/src/pygpukit/__init__.py b/src/pygpukit/__init__.py index 44c2636..42553f8 100644 --- a/src/pygpukit/__init__.py +++ b/src/pygpukit/__init__.py @@ -1,6 +1,6 @@ """PyGPUkit - A lightweight GPU runtime for Python.""" -__version__ = "0.2.11" +__version__ = "0.2.15" # LLM support (safetensors loader) from pygpukit import llm, ops @@ -41,8 +41,12 @@ warmup, ) from pygpukit.ops.basic import ( + abs, add, + argmax, bias_add_inplace, + clamp, + cos, div, exp, gelu, @@ -52,12 +56,21 @@ matmul, max, mean, + min, mul, + neg, relu, + rsqrt, + sigmoid, + sin, softmax, + sqrt, sub, sum, + sum_axis, + tanh, transpose, + where, ) # Try to import Rust types, fallback to Python implementations @@ -141,25 +154,39 @@ "check_driver_compatibility", # Operations "ops", # ops module for advanced usage + "abs", "add", - "sub", - "mul", + "argmax", + "clamp", + "cos", "div", "exp", - "log", - "relu", "gelu", - "softmax", "layernorm", + "log", "matmul", + "mul", + "neg", + "relu", + "rsqrt", + "sigmoid", + "sin", + "softmax", + "sqrt", + "sub", + "tanh", "transpose", + "where", # Fused operations "bias_add_inplace", "linear_bias_gelu", # Reductions - "sum", - "mean", + "argmax", "max", + "mean", + "min", + "sum", + "sum_axis", # LLM support "llm", # CUDA Graph diff --git a/src/pygpukit/asr/__init__.py b/src/pygpukit/asr/__init__.py new file mode 100644 index 0000000..31a02d3 --- /dev/null +++ b/src/pygpukit/asr/__init__.py @@ -0,0 +1,44 @@ +"""ASR (Automatic Speech Recognition) module for PyGPUkit. + +This module provides GPU-accelerated speech recognition models, +starting with Whisper architecture support. + +Example: + >>> from pygpukit.asr import WhisperModel + >>> model = WhisperModel.from_pretrained("kotoba-tech/kotoba-whisper-v2.0") + >>> result = model.transcribe("audio.wav", language="ja") + >>> print(result.text) +""" + +from .preprocessing import ( + WHISPER_CHUNK_LENGTH, + WHISPER_HOP_LENGTH, + WHISPER_N_FFT, + WHISPER_N_MELS, + WHISPER_SAMPLE_RATE, + normalize_mel, + pad_or_trim, + preprocess_audio, +) +from .whisper import ( + TranscriptionResult, + TranscriptionSegment, + WhisperModel, +) + +__all__ = [ + # High-level API + "WhisperModel", + "TranscriptionResult", + "TranscriptionSegment", + # Preprocessing + "preprocess_audio", + "pad_or_trim", + "normalize_mel", + # Constants + "WHISPER_SAMPLE_RATE", + "WHISPER_N_FFT", + "WHISPER_HOP_LENGTH", + "WHISPER_N_MELS", + "WHISPER_CHUNK_LENGTH", +] diff --git a/src/pygpukit/asr/preprocessing.py b/src/pygpukit/asr/preprocessing.py new file mode 100644 index 0000000..c7f86af --- /dev/null +++ b/src/pygpukit/asr/preprocessing.py @@ -0,0 +1,214 @@ +"""Whisper-compatible audio preprocessing. + +This module provides GPU-accelerated audio preprocessing compatible with +OpenAI Whisper and derived models (kotoba-whisper, faster-whisper, etc.). + +Whisper Preprocessing Pipeline: + 1. Resample to 16kHz (if needed) + 2. Pad/trim to 30 seconds (480,000 samples) + 3. STFT: n_fft=400, hop_length=160, window=hann + 4. Mel filterbank: 80 channels, fmin=0, fmax=8000 + 5. 
Log-mel: log10(max(mel, 1e-10)) + 6. Normalize: (log_mel + 4.0) / 4.0 + +Reference: + https://github.com/openai/whisper/blob/main/whisper/audio.py +""" + +from typing import Optional, Union + +import numpy as np + +from ..core import GPUArray, from_numpy +from ..ops import audio + +# Whisper audio constants +WHISPER_SAMPLE_RATE = 16000 +WHISPER_N_FFT = 400 +WHISPER_HOP_LENGTH = 160 +WHISPER_N_MELS = 80 +WHISPER_CHUNK_LENGTH = 30 # seconds +WHISPER_N_SAMPLES = WHISPER_SAMPLE_RATE * WHISPER_CHUNK_LENGTH # 480000 +WHISPER_N_FRAMES = WHISPER_N_SAMPLES // WHISPER_HOP_LENGTH # 3000 + + +def pad_or_trim( + audio_data: Union[GPUArray, np.ndarray], + length: int = WHISPER_N_SAMPLES, +) -> GPUArray: + """Pad or trim audio to exact length. + + Args: + audio_data: Input audio samples (float32) + length: Target length in samples (default: 480000 for 30s @ 16kHz) + + Returns: + GPUArray of exact length, zero-padded or trimmed + """ + # Convert to GPUArray if numpy + if isinstance(audio_data, np.ndarray): + audio_data = from_numpy(audio_data.astype(np.float32)) + + current_length = audio_data.shape[0] + + if current_length == length: + return audio_data + + if current_length > length: + # Trim + return audio_data[:length] + else: + # Pad with zeros + pad_length = length - current_length + padding = from_numpy(np.zeros(pad_length, dtype=np.float32)) + # Concatenate on GPU + result_np = np.concatenate([audio_data.to_numpy(), padding.to_numpy()]) + return from_numpy(result_np) + + +def normalize_mel(log_mel: Union[GPUArray, np.ndarray]) -> GPUArray: + """Apply Whisper-style normalization to log-mel spectrogram. + + Whisper normalization: (log_mel + 4.0) / 4.0 + + This centers the values around 0 and scales them to roughly [-1, 1] range. + + Args: + log_mel: Log-mel spectrogram [n_mels, n_frames] or [n_frames, n_mels] + + Returns: + Normalized log-mel spectrogram as GPUArray + """ + # Convert to GPUArray if numpy + if isinstance(log_mel, np.ndarray): + log_mel = from_numpy(log_mel.astype(np.float32)) + + # (log_mel + 4.0) / 4.0 + return (log_mel + 4.0) / 4.0 + + +def preprocess_audio( + audio_input: Union[GPUArray, np.ndarray, str], + sample_rate: Optional[int] = None, + n_mels: int = WHISPER_N_MELS, + padding: bool = True, +) -> GPUArray: + """Preprocess audio for Whisper model inference. + + Complete preprocessing pipeline: + 1. Load audio (if path provided) + 2. Resample to 16kHz (if needed) + 3. Pad/trim to 30 seconds + 4. Compute log-mel spectrogram + 5. 
Apply Whisper normalization + + Args: + audio_input: Audio samples (GPUArray/ndarray) or file path + sample_rate: Sample rate of input audio (required if not 16kHz) + n_mels: Number of mel bands (default: 80) + padding: Whether to pad short audio to 30s (default: True) + + Returns: + Preprocessed mel spectrogram [n_mels, n_frames] ready for encoder + Shape: [80, 3000] for 30s audio + + Example: + >>> mel = preprocess_audio("audio.wav") + >>> print(mel.shape) # [80, 3000] + >>> # Feed to encoder + >>> encoder_output = encoder(mel.unsqueeze(0)) + """ + # Handle file path input + if isinstance(audio_input, str): + # Load audio file using audio module + audio_buf = audio.load_audio(audio_input) + samples = audio_buf + input_sample_rate = WHISPER_SAMPLE_RATE # Assume load_audio resamples + elif isinstance(audio_input, np.ndarray): + samples = from_numpy(audio_input.astype(np.float32)) + input_sample_rate = sample_rate or WHISPER_SAMPLE_RATE + elif isinstance(audio_input, GPUArray): + samples = audio_input + input_sample_rate = sample_rate or WHISPER_SAMPLE_RATE + else: + raise TypeError(f"Unsupported audio input type: {type(audio_input)}") + + # Resample if needed + if input_sample_rate != WHISPER_SAMPLE_RATE: + samples = audio.resample(samples, input_sample_rate, WHISPER_SAMPLE_RATE) + + # Pad or trim to 30 seconds + if padding: + samples = pad_or_trim(samples, WHISPER_N_SAMPLES) + + # Compute STFT + stft_out = audio.stft( + samples, + n_fft=WHISPER_N_FFT, + hop_length=WHISPER_HOP_LENGTH, + center=True, + ) + + # Compute power spectrum + power = audio.power_spectrum(stft_out) + + # Create and apply mel filterbank + mel_fb = audio.create_mel_filterbank( + n_mels=n_mels, + n_fft=WHISPER_N_FFT, + sample_rate=WHISPER_SAMPLE_RATE, + f_min=0.0, + f_max=8000.0, + ) + mel = audio.apply_mel_filterbank(power, mel_fb) + + # Log-mel + log_mel = audio.log_mel(mel, eps=1e-10) + + # Whisper normalization + normalized = normalize_mel(log_mel) + + # Transpose to [n_mels, n_frames] for encoder input + # Current shape: [n_frames, n_mels] + # Target shape: [n_mels, n_frames] + result_np = normalized.to_numpy().T + return from_numpy(result_np.astype(np.float32)) + + +def preprocess_audio_batch( + audio_list: list, + sample_rate: Optional[int] = None, + n_mels: int = WHISPER_N_MELS, +) -> GPUArray: + """Preprocess multiple audio samples as a batch. + + Args: + audio_list: List of audio samples (GPUArray/ndarray) or file paths + sample_rate: Sample rate of input audio + n_mels: Number of mel bands + + Returns: + Batch of preprocessed mel spectrograms [batch, n_mels, n_frames] + """ + mels = [] + for audio_input in audio_list: + mel = preprocess_audio(audio_input, sample_rate, n_mels) + mels.append(mel.to_numpy()) + + batch = np.stack(mels, axis=0) + return from_numpy(batch) + + +__all__ = [ + "preprocess_audio", + "preprocess_audio_batch", + "pad_or_trim", + "normalize_mel", + "WHISPER_SAMPLE_RATE", + "WHISPER_N_FFT", + "WHISPER_HOP_LENGTH", + "WHISPER_N_MELS", + "WHISPER_CHUNK_LENGTH", + "WHISPER_N_SAMPLES", + "WHISPER_N_FRAMES", +] diff --git a/src/pygpukit/asr/whisper/__init__.py b/src/pygpukit/asr/whisper/__init__.py new file mode 100644 index 0000000..0ff483a --- /dev/null +++ b/src/pygpukit/asr/whisper/__init__.py @@ -0,0 +1,43 @@ +"""Whisper model implementation for PyGPUkit. 
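Before the ASR additions that follow, a quick usage sketch for the element-wise and reduction ops introduced earlier in this diff. This is a minimal sketch only: it assumes the pygpukit.ops.basic exports listed in src/pygpukit/__init__.py mirror their NumPy counterparts, and the Python-side return type of argmax is an assumption (the CUDA kernels above write a single int64 flat index).

import numpy as np
from pygpukit.core import from_numpy
from pygpukit.ops.basic import argmax, neg, rsqrt, sin

x_np = np.linspace(0.1, 2.0, 8, dtype=np.float32)
x = from_numpy(x_np)

# Element-wise ops should match NumPy within float32 tolerance.
np.testing.assert_allclose(sin(x).to_numpy(), np.sin(x_np), rtol=1e-5)
np.testing.assert_allclose(rsqrt(x).to_numpy(), 1.0 / np.sqrt(x_np), rtol=1e-5)
np.testing.assert_allclose(neg(x).to_numpy(), -x_np)

# argmax reduces the whole array to one flat index; expected result here is 1.
idx = argmax(from_numpy(np.array([1.0, 5.0, 3.0], dtype=np.float32)))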
+ +Supports OpenAI Whisper and derived models: +- openai/whisper-large-v3 +- kotoba-tech/kotoba-whisper-v2.0 (Japanese ASR) +- distil-whisper variants + +Example: + >>> from pygpukit.asr.whisper import WhisperModel + >>> model = WhisperModel.from_pretrained("kotoba-tech/kotoba-whisper-v2.0") + >>> result = model.transcribe("audio.wav", language="ja") + >>> print(result.text) +""" + +from .config import WHISPER_CONFIGS, WhisperConfig +from .decoder import WhisperDecoder, WhisperDecoderLayer, create_decoder +from .encoder import WhisperEncoder, WhisperEncoderLayer, create_encoder +from .loader import WhisperWeights, download_model, load_safetensors, load_whisper_model +from .model import TranscriptionResult, TranscriptionSegment, WhisperModel, WhisperTokenizer + +__all__ = [ + # High-level API + "WhisperModel", + "WhisperTokenizer", + "TranscriptionResult", + "TranscriptionSegment", + # Config + "WhisperConfig", + "WHISPER_CONFIGS", + # Loader + "WhisperWeights", + "load_whisper_model", + "load_safetensors", + "download_model", + # Encoder + "WhisperEncoder", + "WhisperEncoderLayer", + "create_encoder", + # Decoder + "WhisperDecoder", + "WhisperDecoderLayer", + "create_decoder", +] diff --git a/src/pygpukit/asr/whisper/config.py b/src/pygpukit/asr/whisper/config.py new file mode 100644 index 0000000..c9a82fe --- /dev/null +++ b/src/pygpukit/asr/whisper/config.py @@ -0,0 +1,253 @@ +"""Whisper model configuration. + +Supports various Whisper variants: +- OpenAI Whisper (tiny, base, small, medium, large, large-v2, large-v3) +- Distilled Whisper (kotoba-whisper, distil-whisper) +""" + +import json +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class WhisperConfig: + """Configuration for Whisper models. + + Attributes: + d_model: Hidden dimension (512-1280 depending on model size) + encoder_layers: Number of encoder transformer layers + decoder_layers: Number of decoder transformer layers + encoder_attention_heads: Number of attention heads in encoder + decoder_attention_heads: Number of attention heads in decoder + encoder_ffn_dim: Feed-forward dimension in encoder + decoder_ffn_dim: Feed-forward dimension in decoder + vocab_size: Vocabulary size (51865 for multilingual, 51864 for English-only) + num_mel_bins: Number of mel spectrogram bins (80 or 128) + max_source_positions: Maximum encoder sequence length (1500 for 30s audio) + max_target_positions: Maximum decoder sequence length (448 tokens) + activation_function: Activation function (gelu) + dropout: Dropout rate + attention_dropout: Attention dropout rate + activation_dropout: Activation dropout rate + bos_token_id: Beginning of sequence token ID + eos_token_id: End of sequence token ID + pad_token_id: Padding token ID + decoder_start_token_id: Decoder start token ID + """ + + # Model architecture + d_model: int = 1280 + encoder_layers: int = 32 + decoder_layers: int = 32 + encoder_attention_heads: int = 20 + decoder_attention_heads: int = 20 + encoder_ffn_dim: int = 5120 + decoder_ffn_dim: int = 5120 + + # Vocabulary + vocab_size: int = 51866 + + # Audio + num_mel_bins: int = 128 # 80 for older Whisper, 128 for large-v3 + + # Sequence lengths + max_source_positions: int = 1500 # 30s audio / 160 hop_length / 2 + max_target_positions: int = 448 + + # Activation and regularization + activation_function: str = "gelu" + dropout: float = 0.0 + attention_dropout: float = 0.0 + activation_dropout: float = 0.0 + + # Special tokens + bos_token_id: int = 50257 + eos_token_id: int = 50257 + pad_token_id: int = 
50256 + decoder_start_token_id: int = 50258 + + # Suppress tokens + begin_suppress_tokens: list = field(default_factory=lambda: [220, 50257]) + + # Inference + use_cache: bool = True + torch_dtype: str = "bfloat16" + + # Model name + model_name_or_path: Optional[str] = None + + @classmethod + def from_dict(cls, config_dict: dict) -> "WhisperConfig": + """Create config from dictionary.""" + # Map HuggingFace config keys to our keys + key_mapping = { + "_name_or_path": "model_name_or_path", + } + + mapped_dict = {} + for key, value in config_dict.items(): + mapped_key = key_mapping.get(key, key) + if hasattr(cls, "__dataclass_fields__") and mapped_key in cls.__dataclass_fields__: + mapped_dict[mapped_key] = value + + return cls(**mapped_dict) + + @classmethod + def from_json(cls, json_path: str) -> "WhisperConfig": + """Load config from JSON file.""" + with open(json_path, encoding="utf-8") as f: + config_dict = json.load(f) + return cls.from_dict(config_dict) + + @classmethod + def from_pretrained(cls, model_path: str) -> "WhisperConfig": + """Load config from pretrained model directory or HuggingFace hub.""" + import os + + # Check for local config.json + if os.path.isdir(model_path): + config_path = os.path.join(model_path, "config.json") + if os.path.exists(config_path): + return cls.from_json(config_path) + + # Try HuggingFace hub + try: + from huggingface_hub import hf_hub_download + + config_path = hf_hub_download(repo_id=model_path, filename="config.json") + return cls.from_json(config_path) + except ImportError as err: + raise ImportError( + "huggingface_hub is required to download from HuggingFace. " + "Install with: pip install huggingface_hub" + ) from err + + def to_dict(self) -> dict: + """Convert config to dictionary.""" + return { + "d_model": self.d_model, + "encoder_layers": self.encoder_layers, + "decoder_layers": self.decoder_layers, + "encoder_attention_heads": self.encoder_attention_heads, + "decoder_attention_heads": self.decoder_attention_heads, + "encoder_ffn_dim": self.encoder_ffn_dim, + "decoder_ffn_dim": self.decoder_ffn_dim, + "vocab_size": self.vocab_size, + "num_mel_bins": self.num_mel_bins, + "max_source_positions": self.max_source_positions, + "max_target_positions": self.max_target_positions, + "activation_function": self.activation_function, + "dropout": self.dropout, + "attention_dropout": self.attention_dropout, + "activation_dropout": self.activation_dropout, + "bos_token_id": self.bos_token_id, + "eos_token_id": self.eos_token_id, + "pad_token_id": self.pad_token_id, + "decoder_start_token_id": self.decoder_start_token_id, + } + + @property + def head_dim(self) -> int: + """Dimension per attention head.""" + return self.d_model // self.encoder_attention_heads + + @property + def is_distilled(self) -> bool: + """Check if this is a distilled model (fewer decoder layers).""" + return self.decoder_layers < self.encoder_layers + + def __repr__(self) -> str: + return ( + f"WhisperConfig(\n" + f" d_model={self.d_model},\n" + f" encoder_layers={self.encoder_layers},\n" + f" decoder_layers={self.decoder_layers},\n" + f" attention_heads={self.encoder_attention_heads},\n" + f" ffn_dim={self.encoder_ffn_dim},\n" + f" vocab_size={self.vocab_size},\n" + f" num_mel_bins={self.num_mel_bins},\n" + f" distilled={self.is_distilled}\n" + f")" + ) + + +# Predefined configurations for common Whisper variants +WHISPER_CONFIGS = { + "tiny": WhisperConfig( + d_model=384, + encoder_layers=4, + decoder_layers=4, + encoder_attention_heads=6, + decoder_attention_heads=6, + 
encoder_ffn_dim=1536, + decoder_ffn_dim=1536, + num_mel_bins=80, + ), + "base": WhisperConfig( + d_model=512, + encoder_layers=6, + decoder_layers=6, + encoder_attention_heads=8, + decoder_attention_heads=8, + encoder_ffn_dim=2048, + decoder_ffn_dim=2048, + num_mel_bins=80, + ), + "small": WhisperConfig( + d_model=768, + encoder_layers=12, + decoder_layers=12, + encoder_attention_heads=12, + decoder_attention_heads=12, + encoder_ffn_dim=3072, + decoder_ffn_dim=3072, + num_mel_bins=80, + ), + "medium": WhisperConfig( + d_model=1024, + encoder_layers=24, + decoder_layers=24, + encoder_attention_heads=16, + decoder_attention_heads=16, + encoder_ffn_dim=4096, + decoder_ffn_dim=4096, + num_mel_bins=80, + ), + "large": WhisperConfig( + d_model=1280, + encoder_layers=32, + decoder_layers=32, + encoder_attention_heads=20, + decoder_attention_heads=20, + encoder_ffn_dim=5120, + decoder_ffn_dim=5120, + num_mel_bins=80, + ), + "large-v3": WhisperConfig( + d_model=1280, + encoder_layers=32, + decoder_layers=32, + encoder_attention_heads=20, + decoder_attention_heads=20, + encoder_ffn_dim=5120, + decoder_ffn_dim=5120, + num_mel_bins=128, # large-v3 uses 128 mel bins + ), + "kotoba-v2": WhisperConfig( + d_model=1280, + encoder_layers=32, + decoder_layers=2, # Distilled! + encoder_attention_heads=20, + decoder_attention_heads=20, + encoder_ffn_dim=5120, + decoder_ffn_dim=5120, + num_mel_bins=128, + ), +} + + +__all__ = [ + "WhisperConfig", + "WHISPER_CONFIGS", +] diff --git a/src/pygpukit/asr/whisper/decoder.py b/src/pygpukit/asr/whisper/decoder.py new file mode 100644 index 0000000..caf3217 --- /dev/null +++ b/src/pygpukit/asr/whisper/decoder.py @@ -0,0 +1,559 @@ +"""Whisper decoder implementation. + +The Whisper decoder generates text tokens from encoder hidden states: +1. Token embedding lookup +2. Sinusoidal positional embeddings +3. N transformer decoder layers: + - Causal self-attention + - Cross-attention to encoder outputs + - FFN +4. Final layer normalization +5. Output projection to vocabulary + +Architecture (Large-v3 / kotoba-whisper-v2.0): +- Input: token IDs [batch, seq_len] +- Encoder states: [batch, 1500, 1280] +- Transformer: 2-32 layers depending on distillation +- Output: logits [batch, seq_len, vocab_size] +""" + +from __future__ import annotations + +import math + +import numpy as np + +from ...core import GPUArray, from_numpy +from ...ops.matmul import matmul +from ...ops.nn import gelu, layernorm +from .config import WhisperConfig +from .loader import WhisperWeights + + +def _softmax_2d(x: GPUArray) -> GPUArray: + """Softmax over last dimension for 2D tensor. + + Args: + x: Input [batch, features] + + Returns: + Softmax output [batch, features] + """ + # Use GPU softmax kernel + from ...ops.reduction import softmax + + return softmax(x) + + +def _softmax_4d(x: GPUArray) -> GPUArray: + """Softmax over last dimension for 4D attention weights. + + Args: + x: Input [batch, heads, seq_q, seq_k] + + Returns: + Softmax output [batch, heads, seq_q, seq_k] + """ + # Use GPU softmax kernel (supports 2D/3D/4D) + from ...ops.reduction import softmax + + return softmax(x) + + +def _batched_matmul(a: GPUArray, b: GPUArray) -> GPUArray: + """Batched matrix multiplication for 4D tensors. 
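A short sketch of how the predefined configs above can be inspected. The import path follows the file layout in this diff (src/pygpukit/asr/whisper/config.py), and the asserted values come straight from the dataclass defaults and properties.

from pygpukit.asr.whisper.config import WHISPER_CONFIGS, WhisperConfig

cfg = WHISPER_CONFIGS["kotoba-v2"]
assert cfg.head_dim == 1280 // 20   # 64, via the head_dim property
assert cfg.is_distilled             # 2 decoder layers < 32 encoder layers

# from_dict keeps only known dataclass fields, so a raw HuggingFace config.json
# dict with extra keys can be passed through unchanged.
cfg2 = WhisperConfig.from_dict({"d_model": 1280, "decoder_layers": 2, "unknown_key": 1})
assert cfg2.decoder_layers == 2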
+ + Args: + a: Input [batch, heads, M, K] + b: Input [batch, heads, K, N] + + Returns: + Output [batch, heads, M, N] + """ + # Use GPU batched matmul kernel + from ...ops.matmul import batched_matmul + + return batched_matmul(a, b) + + +def _create_causal_mask(seq_len: int, dtype: np.dtype) -> np.ndarray: + """Create causal attention mask. + + Args: + seq_len: Sequence length + dtype: Output dtype + + Returns: + Mask [1, 1, seq_len, seq_len] where upper triangle is -inf + """ + mask = np.triu(np.ones((seq_len, seq_len), dtype=dtype) * float("-inf"), k=1) + return mask.reshape(1, 1, seq_len, seq_len) + + +class WhisperDecoderLayer: + """Single Whisper decoder transformer layer. + + Architecture: + x = x + self_attention(layer_norm(x)) + x = x + cross_attention(layer_norm(x), encoder_hidden_states) + x = x + ffn(layer_norm(x)) + """ + + def __init__( + self, + config: WhisperConfig, + layer_weights: dict, + ): + self.config = config + self.d_model = config.d_model + self.n_heads = config.decoder_attention_heads + self.head_dim = config.d_model // config.decoder_attention_heads + + # Load weights as GPUArrays + self._load_weights(layer_weights) + + def _load_weights(self, weights: dict) -> None: + """Load layer weights to GPU.""" + + def _to_gpu(arr): + """Convert numpy array to GPUArray, handling None.""" + return from_numpy(arr) if arr is not None else None + + # Self attention + self.self_attn_q_weight = _to_gpu(weights["self_attn_q_weight"]) + self.self_attn_q_bias = _to_gpu(weights["self_attn_q_bias"]) + self.self_attn_k_weight = _to_gpu(weights["self_attn_k_weight"]) + self.self_attn_k_bias = _to_gpu(weights["self_attn_k_bias"]) + self.self_attn_v_weight = _to_gpu(weights["self_attn_v_weight"]) + self.self_attn_v_bias = _to_gpu(weights["self_attn_v_bias"]) + self.self_attn_out_weight = _to_gpu(weights["self_attn_out_weight"]) + self.self_attn_out_bias = _to_gpu(weights["self_attn_out_bias"]) + + # Self attention layer norm + self.self_attn_ln_weight = _to_gpu(weights["self_attn_layer_norm_weight"]) + self.self_attn_ln_bias = _to_gpu(weights["self_attn_layer_norm_bias"]) + + # Cross attention + self.cross_attn_q_weight = _to_gpu(weights["cross_attn_q_weight"]) + self.cross_attn_q_bias = _to_gpu(weights["cross_attn_q_bias"]) + self.cross_attn_k_weight = _to_gpu(weights["cross_attn_k_weight"]) + self.cross_attn_k_bias = _to_gpu(weights["cross_attn_k_bias"]) + self.cross_attn_v_weight = _to_gpu(weights["cross_attn_v_weight"]) + self.cross_attn_v_bias = _to_gpu(weights["cross_attn_v_bias"]) + self.cross_attn_out_weight = _to_gpu(weights["cross_attn_out_weight"]) + self.cross_attn_out_bias = _to_gpu(weights["cross_attn_out_bias"]) + + # Cross attention layer norm + self.cross_attn_ln_weight = _to_gpu(weights["cross_attn_layer_norm_weight"]) + self.cross_attn_ln_bias = _to_gpu(weights["cross_attn_layer_norm_bias"]) + + # FFN + self.fc1_weight = _to_gpu(weights["fc1_weight"]) + self.fc1_bias = _to_gpu(weights["fc1_bias"]) + self.fc2_weight = _to_gpu(weights["fc2_weight"]) + self.fc2_bias = _to_gpu(weights["fc2_bias"]) + + # Final layer norm + self.ffn_ln_weight = _to_gpu(weights["final_layer_norm_weight"]) + self.ffn_ln_bias = _to_gpu(weights["final_layer_norm_bias"]) + + def __call__( + self, + x: GPUArray, + encoder_hidden_states: GPUArray, + causal_mask: GPUArray | None = None, + ) -> GPUArray: + """Forward pass through decoder layer. 
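To make the causal masking concrete, here is a NumPy-only sketch of what _create_causal_mask above produces for a tiny sequence; the helper is module-private, so the construction is repeated inline rather than imported.

import numpy as np

seq_len = 4
mask = np.triu(np.ones((seq_len, seq_len), dtype=np.float32) * float("-inf"), k=1)
mask = mask.reshape(1, 1, seq_len, seq_len)

# Query position i may only attend to key positions <= i.
assert mask[0, 0, 0, 0] == 0.0 and np.isneginf(mask[0, 0, 0, 1])
assert np.all(mask[0, 0, -1] == 0.0)  # the last position sees every key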
+ + Args: + x: Input tensor [batch, seq_len, d_model] + encoder_hidden_states: Encoder output [batch, enc_seq_len, d_model] + causal_mask: Optional causal mask [1, 1, seq_len, seq_len] + + Returns: + Output tensor [batch, seq_len, d_model] + """ + # Self attention block (with causal masking) + residual = x + x = self._layer_norm(x, self.self_attn_ln_weight, self.self_attn_ln_bias) + x = self._self_attention(x, causal_mask) + x = residual + x + + # Cross attention block + residual = x + x = self._layer_norm(x, self.cross_attn_ln_weight, self.cross_attn_ln_bias) + x = self._cross_attention(x, encoder_hidden_states) + x = residual + x + + # FFN block + residual = x + x = self._layer_norm(x, self.ffn_ln_weight, self.ffn_ln_bias) + x = self._ffn(x) + x = residual + x + + return x + + def _layer_norm( + self, x: GPUArray, weight: GPUArray, bias: GPUArray, eps: float = 1e-5 + ) -> GPUArray: + """Apply layer normalization.""" + return layernorm(x, weight, bias, eps=eps) + + def _self_attention(self, x: GPUArray, causal_mask: GPUArray | None = None) -> GPUArray: + """Causal multi-head self attention. + + Args: + x: Input [batch, seq_len, d_model] + causal_mask: Causal mask [1, 1, seq_len, seq_len] + + Returns: + Attention output [batch, seq_len, d_model] + """ + batch_size = x.shape[0] + seq_len = x.shape[1] + + # Project Q, K, V + q = self._linear(x, self.self_attn_q_weight, self.self_attn_q_bias) + k = self._linear(x, self.self_attn_k_weight, self.self_attn_k_bias) + v = self._linear(x, self.self_attn_v_weight, self.self_attn_v_bias) + + # Reshape for multi-head attention: [batch, seq, n_heads, head_dim] + q = q.reshape(batch_size, seq_len, self.n_heads, self.head_dim) + k = k.reshape(batch_size, seq_len, self.n_heads, self.head_dim) + v = v.reshape(batch_size, seq_len, self.n_heads, self.head_dim) + + # Transpose to [batch, n_heads, seq, head_dim] + q = q.transpose(0, 2, 1, 3) + k = k.transpose(0, 2, 1, 3) + v = v.transpose(0, 2, 1, 3) + + # Scaled dot-product attention with causal mask + scale = 1.0 / math.sqrt(self.head_dim) + attn_weights = _batched_matmul(q, k.transpose(0, 1, 3, 2)) * scale + + # Apply causal mask + if causal_mask is not None: + attn_weights = attn_weights + causal_mask + + # Softmax + attn_weights = _softmax_4d(attn_weights) + + # Apply attention to values + attn_output = _batched_matmul(attn_weights, v) + + # Reshape back: [batch, n_heads, seq, head_dim] -> [batch, seq, d_model] + attn_output = attn_output.transpose(0, 2, 1, 3) + attn_output = attn_output.reshape(batch_size, seq_len, self.d_model) + + # Output projection + output = self._linear(attn_output, self.self_attn_out_weight, self.self_attn_out_bias) + + return output + + def _cross_attention(self, x: GPUArray, encoder_hidden_states: GPUArray) -> GPUArray: + """Cross attention to encoder outputs. 
+ + Args: + x: Decoder input [batch, dec_seq_len, d_model] + encoder_hidden_states: Encoder output [batch, enc_seq_len, d_model] + + Returns: + Attention output [batch, dec_seq_len, d_model] + """ + batch_size = x.shape[0] + dec_seq_len = x.shape[1] + enc_seq_len = encoder_hidden_states.shape[1] + + # Q from decoder, K/V from encoder + q = self._linear(x, self.cross_attn_q_weight, self.cross_attn_q_bias) + k = self._linear(encoder_hidden_states, self.cross_attn_k_weight, self.cross_attn_k_bias) + v = self._linear(encoder_hidden_states, self.cross_attn_v_weight, self.cross_attn_v_bias) + + # Reshape for multi-head attention + q = q.reshape(batch_size, dec_seq_len, self.n_heads, self.head_dim) + k = k.reshape(batch_size, enc_seq_len, self.n_heads, self.head_dim) + v = v.reshape(batch_size, enc_seq_len, self.n_heads, self.head_dim) + + # Transpose to [batch, n_heads, seq, head_dim] + q = q.transpose(0, 2, 1, 3) + k = k.transpose(0, 2, 1, 3) + v = v.transpose(0, 2, 1, 3) + + # Scaled dot-product attention (no causal mask for cross attention) + scale = 1.0 / math.sqrt(self.head_dim) + attn_weights = _batched_matmul(q, k.transpose(0, 1, 3, 2)) * scale + + # Softmax + attn_weights = _softmax_4d(attn_weights) + + # Apply attention to values + attn_output = _batched_matmul(attn_weights, v) + + # Reshape back: [batch, n_heads, seq, head_dim] -> [batch, seq, d_model] + attn_output = attn_output.transpose(0, 2, 1, 3) + attn_output = attn_output.reshape(batch_size, dec_seq_len, self.d_model) + + # Output projection + output = self._linear(attn_output, self.cross_attn_out_weight, self.cross_attn_out_bias) + + return output + + def _ffn(self, x: GPUArray) -> GPUArray: + """Feed-forward network with GELU activation. + + Args: + x: Input [batch, seq_len, d_model] + + Returns: + FFN output [batch, seq_len, d_model] + """ + # fc1: d_model -> ffn_dim + h = self._linear(x, self.fc1_weight, self.fc1_bias) + + # GELU activation + h = gelu(h) + + # fc2: ffn_dim -> d_model + output = self._linear(h, self.fc2_weight, self.fc2_bias) + + return output + + def _linear(self, x: GPUArray, weight: GPUArray, bias: GPUArray) -> GPUArray: + """Linear projection: y = xW^T + b. + + Handles both 2D [batch, features] and 3D [batch, seq_len, features] input. + """ + weight_t = weight.T + out_features = weight.shape[0] + + if x.ndim == 3: + batch, seq_len, in_features = x.shape + x_2d = x.reshape(batch * seq_len, in_features) + out_2d = matmul(x_2d, weight_t) + # Add bias in 2D (broadcasting works naturally) + if bias is not None: + out_2d = out_2d + bias + out = out_2d.reshape(batch, seq_len, out_features) + else: + out = matmul(x, weight_t) + if bias is not None: + out = out + bias + return out + + +class WhisperDecoder: + """Whisper text decoder. + + Generates text tokens from encoder hidden states using + autoregressive decoding. 
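A NumPy sketch of the _linear convention used by the attention and FFN blocks above: weights are stored as [out_features, in_features], so the projection is x @ W.T + b, with 3-D activations flattened to 2-D around the matmul and reshaped back afterwards.

import numpy as np

batch, seq, d_in, d_out = 2, 5, 8, 16
x = np.random.randn(batch, seq, d_in).astype(np.float32)
w = np.random.randn(d_out, d_in).astype(np.float32)   # [out_features, in_features]
b = np.random.randn(d_out).astype(np.float32)

flat = x.reshape(batch * seq, d_in) @ w.T + b          # what _linear does
ref = np.einsum("bsi,oi->bso", x, w) + b               # direct 3-D projection
np.testing.assert_allclose(flat.reshape(batch, seq, d_out), ref, rtol=1e-5, atol=1e-5)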
+ """ + + def __init__(self, config: WhisperConfig, weights: WhisperWeights): + self.config = config + self.d_model = config.d_model + self.n_layers = config.decoder_layers + self.vocab_size = config.vocab_size + + # Load weights + self._load_weights(weights) + + # Create decoder layers + self.layers = [] + for layer_weights in weights.decoder_layers: + layer = WhisperDecoderLayer(config, layer_weights) + self.layers.append(layer) + + # Cached causal mask + self._cached_mask: GPUArray | None = None + self._cached_mask_size: int = 0 + + def _load_weights(self, weights: WhisperWeights) -> None: + """Load decoder-specific weights.""" + + def _to_gpu(arr): + """Convert numpy array to GPUArray, handling None.""" + return from_numpy(arr) if arr is not None else None + + # Token embeddings + self.embed_tokens = _to_gpu(weights.decoder_embed_tokens) + + # Positional embeddings + self.embed_positions = _to_gpu(weights.decoder_embed_positions) + + # Final layer norm + self.layer_norm_weight = _to_gpu(weights.decoder_layer_norm_weight) + self.layer_norm_bias = _to_gpu(weights.decoder_layer_norm_bias) + + # Output projection + self.proj_out = _to_gpu(weights.proj_out_weight) + + def __call__( + self, + input_ids: GPUArray, + encoder_hidden_states: GPUArray, + past_key_values: list | None = None, + ) -> GPUArray: + """Decode tokens given encoder outputs. + + Args: + input_ids: Token IDs [batch, seq_len] + encoder_hidden_states: Encoder output [batch, enc_seq_len, d_model] + past_key_values: Optional cached key/values for incremental decoding + + Returns: + Logits [batch, seq_len, vocab_size] + """ + seq_len = input_ids.shape[1] + + # Token embedding lookup + x = self._embed_tokens(input_ids) + + # Add positional embeddings + positions = self.embed_positions[:seq_len] + # Add batch dimension for broadcasting: [seq_len, d_model] -> [1, seq_len, d_model] + positions = positions.reshape(1, seq_len, -1) + x = x + positions + + # Get causal mask + causal_mask = self._get_causal_mask(seq_len, x.to_numpy().dtype) + + # Transformer layers + for layer in self.layers: + x = layer(x, encoder_hidden_states, causal_mask) + + # Final layer norm + x = layernorm(x, self.layer_norm_weight, self.layer_norm_bias) + + # Output projection to vocabulary + # x is [batch, seq_len, d_model], proj_out is [vocab_size, d_model] + batch, seq_len, d_model = x.shape + x_2d = x.reshape(batch * seq_len, d_model) + logits_2d = matmul(x_2d, self.proj_out.T) + logits = logits_2d.reshape(batch, seq_len, -1) + + return logits + + def _embed_tokens(self, input_ids: GPUArray) -> GPUArray: + """Lookup token embeddings. + + Args: + input_ids: Token IDs [batch, seq_len] + + Returns: + Embeddings [batch, seq_len, d_model] + """ + # CPU fallback implementation + ids: np.ndarray = input_ids.to_numpy().astype(np.int64) + embed = self.embed_tokens.to_numpy() + + batch_size, seq_len = ids.shape + output = np.zeros((batch_size, seq_len, embed.shape[1]), dtype=embed.dtype) + + for b in range(batch_size): + for s in range(seq_len): + output[b, s] = embed[ids[b, s]] + + return from_numpy(output) + + def _get_causal_mask(self, seq_len: int, dtype: np.dtype) -> GPUArray: + """Get or create causal attention mask. 
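For reference, a NumPy-only sketch of the sampling step that WhisperDecoder.generate() below applies to the last position's logits; the top-k branch is shown, and greedy decoding is simply np.argmax over the same vector. The random logits stand in for logits[0, -1, :].

import numpy as np

rng = np.random.default_rng(0)
last_logits = rng.normal(size=51866)        # stand-in for logits[0, -1, :]
top_k = 5
top_k_idx = np.argsort(last_logits)[-top_k:]
top_k_logits = last_logits[top_k_idx]
probs = np.exp(top_k_logits - np.max(top_k_logits))
probs = probs / probs.sum()
next_token = int(top_k_idx[rng.choice(len(top_k_idx), p=probs)])
assert next_token in top_k_idx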
+ + Args: + seq_len: Sequence length + dtype: Mask dtype + + Returns: + Causal mask [1, 1, seq_len, seq_len] + """ + if self._cached_mask is None or self._cached_mask_size < seq_len: + mask = _create_causal_mask(seq_len, dtype) + self._cached_mask = from_numpy(mask) + self._cached_mask_size = seq_len + return self._cached_mask + + # Slice cached mask if needed + if self._cached_mask_size > seq_len: + mask = self._cached_mask.to_numpy()[:, :, :seq_len, :seq_len] + return from_numpy(mask) + + return self._cached_mask + + def generate( + self, + encoder_hidden_states: GPUArray, + max_length: int = 448, + temperature: float = 1.0, + top_k: int | None = None, + ) -> list[int]: + """Generate tokens autoregressively. + + Args: + encoder_hidden_states: Encoder output [1, enc_seq_len, d_model] + max_length: Maximum number of tokens to generate + temperature: Sampling temperature + top_k: Optional top-k sampling + + Returns: + List of generated token IDs + """ + # Start with decoder start token + tokens = [self.config.decoder_start_token_id] + + for _ in range(max_length - 1): + # Create input tensor + input_ids = from_numpy(np.array([tokens], dtype=np.int64)) + + # Forward pass + logits = self(input_ids, encoder_hidden_states) + + # Get logits for last token + last_logits = logits.to_numpy()[0, -1, :] # [vocab_size] + + # Apply temperature (skip for greedy decoding) + if temperature > 0.0 and temperature != 1.0: + last_logits = last_logits / temperature + + # Sample next token + if top_k is not None: + # Top-k sampling + top_k_idx = np.argsort(last_logits)[-top_k:] + top_k_logits = last_logits[top_k_idx] + probs = np.exp(top_k_logits - np.max(top_k_logits)) + probs = probs / probs.sum() + next_token = top_k_idx[np.random.choice(len(top_k_idx), p=probs)] + else: + # Greedy decoding + next_token = int(np.argmax(last_logits)) + + tokens.append(next_token) + + # Check for end of sequence + if next_token == self.config.eos_token_id: + break + + return tokens + + +def create_decoder(config: WhisperConfig, weights: WhisperWeights) -> WhisperDecoder: + """Create Whisper decoder from config and weights. + + Args: + config: Whisper model configuration + weights: Loaded model weights + + Returns: + Initialized WhisperDecoder + + Example: + >>> config, weights = load_whisper_model("kotoba-tech/kotoba-whisper-v2.0") + >>> decoder = create_decoder(config, weights) + >>> logits = decoder(input_ids, encoder_hidden_states) + """ + return WhisperDecoder(config, weights) + + +__all__ = [ + "WhisperDecoder", + "WhisperDecoderLayer", + "create_decoder", +] diff --git a/src/pygpukit/asr/whisper/encoder.py b/src/pygpukit/asr/whisper/encoder.py new file mode 100644 index 0000000..07d4c0d --- /dev/null +++ b/src/pygpukit/asr/whisper/encoder.py @@ -0,0 +1,411 @@ +"""Whisper encoder implementation. + +The Whisper encoder processes mel spectrograms through: +1. Conv1d stem (2 layers with GELU activation) +2. Sinusoidal positional embeddings +3. N transformer encoder layers (self-attention + FFN) +4. 
Final layer normalization + +Architecture (Large-v3 / kotoba-whisper-v2.0): +- Input: [batch, n_mels, n_frames] = [batch, 128, 3000] +- Conv1d: 128 -> 1280 channels +- Transformer: 32 layers, 20 heads, 1280 dim +- Output: [batch, 1500, 1280] +""" + +import math + +import numpy as np + +from ...core import GPUArray, from_numpy +from ...ops.matmul import matmul +from ...ops.nn import gelu, layernorm +from .config import WhisperConfig +from .loader import WhisperWeights + + +def _softmax_4d(x: GPUArray) -> GPUArray: + """Softmax over last dimension for 4D attention weights. + + Args: + x: Input [batch, heads, seq_q, seq_k] + + Returns: + Softmax output [batch, heads, seq_q, seq_k] + """ + # Use GPU softmax kernel (supports 2D/3D/4D) + from ...ops.reduction import softmax + + return softmax(x) + + +def _batched_matmul(a: GPUArray, b: GPUArray) -> GPUArray: + """Batched matrix multiplication for 4D tensors. + + Args: + a: Input [batch, heads, M, K] + b: Input [batch, heads, K, N] + + Returns: + Output [batch, heads, M, N] + """ + # Use GPU batched matmul kernel + from ...ops.matmul import batched_matmul + + return batched_matmul(a, b) + + +def _conv1d( + x: GPUArray, + weight: GPUArray, + bias: GPUArray, + stride: int = 1, + padding: int = 0, +) -> GPUArray: + """1D convolution using im2col + matmul. + + Args: + x: Input [batch, in_channels, length] + weight: Kernel [out_channels, in_channels, kernel_size] + bias: Bias [out_channels] + stride: Stride + padding: Padding + + Returns: + Output [batch, out_channels, out_length] + """ + # CPU fallback implementation using im2col + # TODO: Implement native GPU conv1d kernel + x_np = x.to_numpy() + w_np = weight.to_numpy() + b_np = bias.to_numpy() if bias is not None else None + + batch, in_channels, length = x_np.shape + out_channels, _, kernel_size = w_np.shape + + # Apply padding + if padding > 0: + x_np = np.pad(x_np, ((0, 0), (0, 0), (padding, padding)), mode="constant") + + # Compute output length + out_length = (x_np.shape[2] - kernel_size) // stride + 1 + + # im2col: extract patches + # Shape: [batch, in_channels * kernel_size, out_length] + col = np.zeros((batch, in_channels * kernel_size, out_length), dtype=x_np.dtype) + for i in range(out_length): + start = i * stride + end = start + kernel_size + col[:, :, i] = x_np[:, :, start:end].reshape(batch, -1) + + # matmul: weight [out_channels, in_channels * kernel_size] @ col + # Result: [batch, out_channels, out_length] + w_flat = w_np.reshape(out_channels, -1) # [out_channels, in_channels * kernel_size] + out = np.zeros((batch, out_channels, out_length), dtype=x_np.dtype) + for b in range(batch): + out[b] = w_flat @ col[b] + + # Add bias + if b_np is not None: + out = out + b_np.reshape(1, -1, 1) + + return from_numpy(out) + + +class WhisperEncoderLayer: + """Single Whisper encoder transformer layer. 
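The frame-count arithmetic behind the conv stem follows directly from the out_length formula in _conv1d above: conv1 (kernel 3, padding 1, stride 1) preserves the 3000 mel frames, and conv2 (stride 2) halves them to 1500, which matches max_source_positions for 30 s of audio. A small check:

def conv1d_out_length(length: int, kernel_size: int, stride: int, padding: int) -> int:
    # Same formula as _conv1d: (padded length - kernel) // stride + 1
    return (length + 2 * padding - kernel_size) // stride + 1

n_frames = 3000                                        # 30 s at hop_length=160
after_conv1 = conv1d_out_length(n_frames, 3, 1, 1)     # 3000
after_conv2 = conv1d_out_length(after_conv1, 3, 2, 1)  # 1500
assert (after_conv1, after_conv2) == (3000, 1500)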
+ + Architecture: + x = x + self_attention(layer_norm(x)) + x = x + ffn(layer_norm(x)) + """ + + def __init__( + self, + config: WhisperConfig, + layer_weights: dict, + ): + self.config = config + self.d_model = config.d_model + self.n_heads = config.encoder_attention_heads + self.head_dim = config.d_model // config.encoder_attention_heads + + # Load weights as GPUArrays + self._load_weights(layer_weights) + + def _load_weights(self, weights: dict) -> None: + """Load layer weights to GPU.""" + + def _to_gpu(arr): + """Convert numpy array to GPUArray, handling None.""" + return from_numpy(arr) if arr is not None else None + + # Self attention + self.q_weight = _to_gpu(weights["self_attn_q_weight"]) + self.q_bias = _to_gpu(weights["self_attn_q_bias"]) + self.k_weight = _to_gpu(weights["self_attn_k_weight"]) + self.k_bias = _to_gpu(weights["self_attn_k_bias"]) + self.v_weight = _to_gpu(weights["self_attn_v_weight"]) + self.v_bias = _to_gpu(weights["self_attn_v_bias"]) + self.out_weight = _to_gpu(weights["self_attn_out_weight"]) + self.out_bias = _to_gpu(weights["self_attn_out_bias"]) + + # Self attention layer norm + self.attn_ln_weight = _to_gpu(weights["self_attn_layer_norm_weight"]) + self.attn_ln_bias = _to_gpu(weights["self_attn_layer_norm_bias"]) + + # FFN + self.fc1_weight = _to_gpu(weights["fc1_weight"]) + self.fc1_bias = _to_gpu(weights["fc1_bias"]) + self.fc2_weight = _to_gpu(weights["fc2_weight"]) + self.fc2_bias = _to_gpu(weights["fc2_bias"]) + + # Final layer norm + self.ffn_ln_weight = _to_gpu(weights["final_layer_norm_weight"]) + self.ffn_ln_bias = _to_gpu(weights["final_layer_norm_bias"]) + + def __call__(self, x: GPUArray) -> GPUArray: + """Forward pass through encoder layer. + + Args: + x: Input tensor [batch, seq_len, d_model] + + Returns: + Output tensor [batch, seq_len, d_model] + """ + # Self attention block + residual = x + x = self._layer_norm(x, self.attn_ln_weight, self.attn_ln_bias) + x = self._self_attention(x) + x = residual + x + + # FFN block + residual = x + x = self._layer_norm(x, self.ffn_ln_weight, self.ffn_ln_bias) + x = self._ffn(x) + x = residual + x + + return x + + def _layer_norm( + self, x: GPUArray, weight: GPUArray, bias: GPUArray, eps: float = 1e-5 + ) -> GPUArray: + """Apply layer normalization.""" + return layernorm(x, weight, bias, eps=eps) + + def _self_attention(self, x: GPUArray) -> GPUArray: + """Multi-head self attention. 
+ + Args: + x: Input [batch, seq_len, d_model] + + Returns: + Attention output [batch, seq_len, d_model] + """ + batch_size = x.shape[0] + seq_len = x.shape[1] + + # Project Q, K, V + q = self._linear(x, self.q_weight, self.q_bias) + k = self._linear(x, self.k_weight, self.k_bias) + v = self._linear(x, self.v_weight, self.v_bias) + + # Reshape for multi-head attention: [batch, seq, n_heads, head_dim] + q = q.reshape(batch_size, seq_len, self.n_heads, self.head_dim) + k = k.reshape(batch_size, seq_len, self.n_heads, self.head_dim) + v = v.reshape(batch_size, seq_len, self.n_heads, self.head_dim) + + # Transpose to [batch, n_heads, seq, head_dim] + q = q.transpose(0, 2, 1, 3) + k = k.transpose(0, 2, 1, 3) + v = v.transpose(0, 2, 1, 3) + + # Scaled dot-product attention + scale = 1.0 / math.sqrt(self.head_dim) + attn_weights = _batched_matmul(q, k.transpose(0, 1, 3, 2)) * scale + + # Softmax over last dimension + attn_weights = _softmax_4d(attn_weights) + + # Apply attention to values + attn_output = _batched_matmul(attn_weights, v) + + # Reshape back: [batch, n_heads, seq, head_dim] -> [batch, seq, d_model] + attn_output = attn_output.transpose(0, 2, 1, 3) + attn_output = attn_output.reshape(batch_size, seq_len, self.d_model) + + # Output projection + output = self._linear(attn_output, self.out_weight, self.out_bias) + + return output + + def _ffn(self, x: GPUArray) -> GPUArray: + """Feed-forward network with GELU activation. + + Args: + x: Input [batch, seq_len, d_model] + + Returns: + FFN output [batch, seq_len, d_model] + """ + # fc1: d_model -> ffn_dim + h = self._linear(x, self.fc1_weight, self.fc1_bias) + + # GELU activation + h = gelu(h) + + # fc2: ffn_dim -> d_model + output = self._linear(h, self.fc2_weight, self.fc2_bias) + + return output + + def _linear(self, x: GPUArray, weight: GPUArray, bias: GPUArray) -> GPUArray: + """Linear projection: y = xW^T + b. + + Handles both 2D [batch, features] and 3D [batch, seq_len, features] input. + """ + # weight is [out_features, in_features], need to transpose + weight_t = weight.T + out_features = weight.shape[0] + + if x.ndim == 3: + # Reshape [batch, seq_len, in_features] -> [batch * seq_len, in_features] + batch, seq_len, in_features = x.shape + x_2d = x.reshape(batch * seq_len, in_features) + out_2d = matmul(x_2d, weight_t) + # Add bias in 2D (broadcasting works naturally) + if bias is not None: + out_2d = out_2d + bias + out = out_2d.reshape(batch, seq_len, out_features) + else: + out = matmul(x, weight_t) + if bias is not None: + out = out + bias + return out + + +class WhisperEncoder: + """Whisper audio encoder. + + Converts mel spectrograms to encoder hidden states. 
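+
+    Example (illustrative; Large-v3 shapes, assumes config/weights are loaded):
+        >>> encoder = WhisperEncoder(config, weights)
+        >>> hidden = encoder(mel)  # [batch, 128, 3000] -> [batch, 1500, 1280]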
+ """ + + def __init__(self, config: WhisperConfig, weights: WhisperWeights): + self.config = config + self.d_model = config.d_model + self.n_layers = config.encoder_layers + + # Load weights + self._load_weights(weights) + + # Create encoder layers + self.layers = [] + for layer_weights in weights.encoder_layers: + layer = WhisperEncoderLayer(config, layer_weights) + self.layers.append(layer) + + def _load_weights(self, weights: WhisperWeights) -> None: + """Load encoder-specific weights.""" + + def _to_gpu(arr): + """Convert numpy array to GPUArray, handling None.""" + return from_numpy(arr) if arr is not None else None + + # Conv1d stem + self.conv1_weight = _to_gpu(weights.encoder_conv1_weight) + self.conv1_bias = _to_gpu(weights.encoder_conv1_bias) + self.conv2_weight = _to_gpu(weights.encoder_conv2_weight) + self.conv2_bias = _to_gpu(weights.encoder_conv2_bias) + + # Positional embeddings + self.embed_positions = _to_gpu(weights.encoder_embed_positions) + + # Final layer norm + self.layer_norm_weight = _to_gpu(weights.encoder_layer_norm_weight) + self.layer_norm_bias = _to_gpu(weights.encoder_layer_norm_bias) + + def __call__(self, mel: GPUArray) -> GPUArray: + """Encode mel spectrogram to hidden states. + + Args: + mel: Mel spectrogram [batch, n_mels, n_frames] + For kotoba-whisper: [batch, 128, 3000] + + Returns: + Encoder hidden states [batch, seq_len, d_model] + For kotoba-whisper: [batch, 1500, 1280] + """ + # Conv1d stem: [batch, n_mels, n_frames] -> [batch, d_model, seq_len] + x = self._conv_stem(mel) + + # Transpose to [batch, seq_len, d_model] + x = x.transpose(0, 2, 1) + + # Add positional embeddings + seq_len = x.shape[1] + max_positions = self.embed_positions.shape[0] + if seq_len > max_positions: + # Clamp to available positions (should not happen with correct preprocessing) + seq_len = max_positions + x = x[:, :seq_len, :] + positions = self.embed_positions[:seq_len] + # Add batch dimension for broadcasting: [seq_len, d_model] -> [1, seq_len, d_model] + positions = positions.reshape(1, seq_len, -1) + x = x + positions + + # Transformer layers + for layer in self.layers: + x = layer(x) + + # Final layer norm + x = layernorm(x, self.layer_norm_weight, self.layer_norm_bias) + + return x + + def _conv_stem(self, mel: GPUArray) -> GPUArray: + """Convolutional stem: 2 Conv1d layers with GELU. + + Conv1: n_mels -> d_model, kernel=3, padding=1 + Conv2: d_model -> d_model, kernel=3, stride=2, padding=1 + + Args: + mel: [batch, n_mels, n_frames] + + Returns: + [batch, d_model, n_frames // 2] + """ + # Conv1: [batch, n_mels, n_frames] -> [batch, d_model, n_frames] + x = _conv1d(mel, self.conv1_weight, self.conv1_bias, padding=1) + x = gelu(x) + + # Conv2: [batch, d_model, n_frames] -> [batch, d_model, n_frames // 2] + x = _conv1d(x, self.conv2_weight, self.conv2_bias, stride=2, padding=1) + x = gelu(x) + + return x + + +def create_encoder(config: WhisperConfig, weights: WhisperWeights) -> WhisperEncoder: + """Create Whisper encoder from config and weights. 
+ + Args: + config: Whisper model configuration + weights: Loaded model weights + + Returns: + Initialized WhisperEncoder + + Example: + >>> config, weights = load_whisper_model("kotoba-tech/kotoba-whisper-v2.0") + >>> encoder = create_encoder(config, weights) + >>> mel = preprocess_audio("audio.wav") # [80, 3000] + >>> hidden = encoder(mel.unsqueeze(0)) # [1, 1500, 1280] + """ + return WhisperEncoder(config, weights) + + +__all__ = [ + "WhisperEncoder", + "WhisperEncoderLayer", + "create_encoder", +] diff --git a/src/pygpukit/asr/whisper/loader.py b/src/pygpukit/asr/whisper/loader.py new file mode 100644 index 0000000..52aa648 --- /dev/null +++ b/src/pygpukit/asr/whisper/loader.py @@ -0,0 +1,398 @@ +"""Whisper model loader for SafeTensors format. + +Loads Whisper models from HuggingFace format (SafeTensors) and maps +tensor names to PyGPUkit internal structure. + +Tensor naming convention in HuggingFace Whisper: + model.encoder.conv1.weight + model.encoder.conv2.weight + model.encoder.embed_positions.weight + model.encoder.layers.{i}.self_attn.{k,v,q,out}_proj.{weight,bias} + model.encoder.layers.{i}.self_attn_layer_norm.{weight,bias} + model.encoder.layers.{i}.fc1.{weight,bias} + model.encoder.layers.{i}.fc2.{weight,bias} + model.encoder.layers.{i}.final_layer_norm.{weight,bias} + model.encoder.layer_norm.{weight,bias} + model.decoder.embed_tokens.weight + model.decoder.embed_positions.weight + model.decoder.layers.{i}.self_attn.{k,v,q,out}_proj.{weight,bias} + model.decoder.layers.{i}.self_attn_layer_norm.{weight,bias} + model.decoder.layers.{i}.encoder_attn.{k,v,q,out}_proj.{weight,bias} + model.decoder.layers.{i}.encoder_attn_layer_norm.{weight,bias} + model.decoder.layers.{i}.fc1.{weight,bias} + model.decoder.layers.{i}.fc2.{weight,bias} + model.decoder.layers.{i}.final_layer_norm.{weight,bias} + model.decoder.layer_norm.{weight,bias} + proj_out.weight (output projection, may be tied to embed_tokens) +""" + +import os +from typing import Optional + +import numpy as np + +from .config import WhisperConfig + + +def _bfloat16_to_float32(data: bytes, shape: tuple) -> np.ndarray: + """Convert raw bfloat16 bytes to float32 numpy array. + + bfloat16 is the upper 16 bits of float32, so we just need to + shift left by 16 bits and view as float32. + + Args: + data: Raw bytes in bfloat16 format + shape: Target tensor shape + + Returns: + float32 numpy array + """ + # Read as uint16 + bf16 = np.frombuffer(data, dtype=np.uint16) + # Pad with zeros to create float32 (bfloat16 is upper 16 bits) + f32_int = bf16.astype(np.uint32) << 16 + # View as float32 + f32 = f32_int.view(np.float32) + return f32.reshape(shape) + + +def load_safetensors(file_path: str) -> dict[str, np.ndarray]: + """Load tensors from SafeTensors file. + + Args: + file_path: Path to .safetensors file + + Returns: + Dictionary mapping tensor names to numpy arrays (float32) + + Note: + bfloat16 tensors are automatically converted to float32 since + numpy doesn't natively support bfloat16. + """ + try: + from safetensors import safe_open + except ImportError as err: + raise ImportError( + "safetensors is required to load models. 
Install with: pip install safetensors" + ) from err + + tensors = {} + + # Check if any tensor is bfloat16 by trying to load + has_bfloat16 = False + with safe_open(file_path, framework="numpy") as f: + for key in f.keys(): + try: + tensors[key] = f.get_tensor(key) + except TypeError as e: + if "bfloat16" in str(e): + has_bfloat16 = True + break + raise + + # If bfloat16 detected, reload with raw bytes conversion + if has_bfloat16: + import json + import struct + + tensors = {} + + # Read safetensors header to get tensor info + with open(file_path, "rb") as f: + # First 8 bytes: header size (uint64 little-endian) + header_size = struct.unpack(" str: + """Download model from HuggingFace Hub. + + Args: + model_id: HuggingFace model ID (e.g., "kotoba-tech/kotoba-whisper-v2.0") + cache_dir: Optional cache directory + + Returns: + Path to downloaded model directory + """ + try: + from huggingface_hub import snapshot_download + except ImportError as err: + raise ImportError( + "huggingface_hub is required to download models. " + "Install with: pip install huggingface_hub" + ) from err + + model_path = snapshot_download( + repo_id=model_id, + cache_dir=cache_dir, + allow_patterns=["*.safetensors", "*.json", "tokenizer.*", "vocab.*", "merges.txt"], + ) + + return model_path + + +class WhisperWeights: + """Container for Whisper model weights. + + Organizes weights into encoder and decoder components with proper + tensor mapping from HuggingFace format. + """ + + def __init__(self, config: WhisperConfig): + self.config = config + + # Encoder weights + self.encoder_conv1_weight: Optional[np.ndarray] = None + self.encoder_conv1_bias: Optional[np.ndarray] = None + self.encoder_conv2_weight: Optional[np.ndarray] = None + self.encoder_conv2_bias: Optional[np.ndarray] = None + self.encoder_embed_positions: Optional[np.ndarray] = None + self.encoder_layers: list = [] + self.encoder_layer_norm_weight: Optional[np.ndarray] = None + self.encoder_layer_norm_bias: Optional[np.ndarray] = None + + # Decoder weights + self.decoder_embed_tokens: Optional[np.ndarray] = None + self.decoder_embed_positions: Optional[np.ndarray] = None + self.decoder_layers: list = [] + self.decoder_layer_norm_weight: Optional[np.ndarray] = None + self.decoder_layer_norm_bias: Optional[np.ndarray] = None + self.proj_out_weight: Optional[np.ndarray] = None + + @classmethod + def from_safetensors( + cls, model_path: str, config: Optional[WhisperConfig] = None + ) -> "WhisperWeights": + """Load weights from SafeTensors file or directory. 
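+
+        Example (illustrative; the path is a placeholder):
+            >>> weights = WhisperWeights.from_safetensors("/path/to/whisper-model")
+            >>> print(weights.summary())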
+ + Args: + model_path: Path to .safetensors file or model directory + config: Optional model config (will load from model_path if not provided) + + Returns: + WhisperWeights instance with loaded tensors + """ + # Resolve paths + if os.path.isdir(model_path): + safetensors_path = os.path.join(model_path, "model.safetensors") + config_path = os.path.join(model_path, "config.json") + else: + safetensors_path = model_path + config_path = os.path.join(os.path.dirname(model_path), "config.json") + + # Load config if not provided + if config is None: + if os.path.exists(config_path): + config = WhisperConfig.from_json(config_path) + else: + raise ValueError(f"Config not provided and config.json not found at {config_path}") + + # Load tensors + tensors = load_safetensors(safetensors_path) + + # Create weights instance and populate + weights = cls(config) + weights._load_encoder_weights(tensors) + weights._load_decoder_weights(tensors) + + return weights + + def _load_encoder_weights(self, tensors: dict[str, np.ndarray]) -> None: + """Load encoder weights from tensor dictionary.""" + # Conv layers + self.encoder_conv1_weight = tensors.get("model.encoder.conv1.weight") + self.encoder_conv1_bias = tensors.get("model.encoder.conv1.bias") + self.encoder_conv2_weight = tensors.get("model.encoder.conv2.weight") + self.encoder_conv2_bias = tensors.get("model.encoder.conv2.bias") + + # Positional embeddings + self.encoder_embed_positions = tensors.get("model.encoder.embed_positions.weight") + + # Final layer norm + self.encoder_layer_norm_weight = tensors.get("model.encoder.layer_norm.weight") + self.encoder_layer_norm_bias = tensors.get("model.encoder.layer_norm.bias") + + # Encoder layers + self.encoder_layers = [] + for i in range(self.config.encoder_layers): + layer = self._load_encoder_layer(tensors, i) + self.encoder_layers.append(layer) + + def _load_encoder_layer(self, tensors: dict[str, np.ndarray], layer_idx: int) -> dict: + """Load weights for a single encoder layer.""" + prefix = f"model.encoder.layers.{layer_idx}" + + return { + # Self attention + "self_attn_q_weight": tensors.get(f"{prefix}.self_attn.q_proj.weight"), + "self_attn_q_bias": tensors.get(f"{prefix}.self_attn.q_proj.bias"), + "self_attn_k_weight": tensors.get(f"{prefix}.self_attn.k_proj.weight"), + "self_attn_k_bias": tensors.get(f"{prefix}.self_attn.k_proj.bias"), + "self_attn_v_weight": tensors.get(f"{prefix}.self_attn.v_proj.weight"), + "self_attn_v_bias": tensors.get(f"{prefix}.self_attn.v_proj.bias"), + "self_attn_out_weight": tensors.get(f"{prefix}.self_attn.out_proj.weight"), + "self_attn_out_bias": tensors.get(f"{prefix}.self_attn.out_proj.bias"), + # Self attention layer norm + "self_attn_layer_norm_weight": tensors.get(f"{prefix}.self_attn_layer_norm.weight"), + "self_attn_layer_norm_bias": tensors.get(f"{prefix}.self_attn_layer_norm.bias"), + # FFN + "fc1_weight": tensors.get(f"{prefix}.fc1.weight"), + "fc1_bias": tensors.get(f"{prefix}.fc1.bias"), + "fc2_weight": tensors.get(f"{prefix}.fc2.weight"), + "fc2_bias": tensors.get(f"{prefix}.fc2.bias"), + # Final layer norm + "final_layer_norm_weight": tensors.get(f"{prefix}.final_layer_norm.weight"), + "final_layer_norm_bias": tensors.get(f"{prefix}.final_layer_norm.bias"), + } + + def _load_decoder_weights(self, tensors: dict[str, np.ndarray]) -> None: + """Load decoder weights from tensor dictionary.""" + # Embeddings + self.decoder_embed_tokens = tensors.get("model.decoder.embed_tokens.weight") + self.decoder_embed_positions = 
tensors.get("model.decoder.embed_positions.weight") + + # Final layer norm + self.decoder_layer_norm_weight = tensors.get("model.decoder.layer_norm.weight") + self.decoder_layer_norm_bias = tensors.get("model.decoder.layer_norm.bias") + + # Output projection (may be tied to embed_tokens) + self.proj_out_weight = tensors.get("proj_out.weight") + if self.proj_out_weight is None: + # Tied weights - use embed_tokens + self.proj_out_weight = self.decoder_embed_tokens + + # Decoder layers + self.decoder_layers = [] + for i in range(self.config.decoder_layers): + layer = self._load_decoder_layer(tensors, i) + self.decoder_layers.append(layer) + + def _load_decoder_layer(self, tensors: dict[str, np.ndarray], layer_idx: int) -> dict: + """Load weights for a single decoder layer.""" + prefix = f"model.decoder.layers.{layer_idx}" + + return { + # Self attention + "self_attn_q_weight": tensors.get(f"{prefix}.self_attn.q_proj.weight"), + "self_attn_q_bias": tensors.get(f"{prefix}.self_attn.q_proj.bias"), + "self_attn_k_weight": tensors.get(f"{prefix}.self_attn.k_proj.weight"), + "self_attn_k_bias": tensors.get(f"{prefix}.self_attn.k_proj.bias"), + "self_attn_v_weight": tensors.get(f"{prefix}.self_attn.v_proj.weight"), + "self_attn_v_bias": tensors.get(f"{prefix}.self_attn.v_proj.bias"), + "self_attn_out_weight": tensors.get(f"{prefix}.self_attn.out_proj.weight"), + "self_attn_out_bias": tensors.get(f"{prefix}.self_attn.out_proj.bias"), + # Self attention layer norm + "self_attn_layer_norm_weight": tensors.get(f"{prefix}.self_attn_layer_norm.weight"), + "self_attn_layer_norm_bias": tensors.get(f"{prefix}.self_attn_layer_norm.bias"), + # Cross attention (encoder_attn) + "cross_attn_q_weight": tensors.get(f"{prefix}.encoder_attn.q_proj.weight"), + "cross_attn_q_bias": tensors.get(f"{prefix}.encoder_attn.q_proj.bias"), + "cross_attn_k_weight": tensors.get(f"{prefix}.encoder_attn.k_proj.weight"), + "cross_attn_k_bias": tensors.get(f"{prefix}.encoder_attn.k_proj.bias"), + "cross_attn_v_weight": tensors.get(f"{prefix}.encoder_attn.v_proj.weight"), + "cross_attn_v_bias": tensors.get(f"{prefix}.encoder_attn.v_proj.bias"), + "cross_attn_out_weight": tensors.get(f"{prefix}.encoder_attn.out_proj.weight"), + "cross_attn_out_bias": tensors.get(f"{prefix}.encoder_attn.out_proj.bias"), + # Cross attention layer norm + "cross_attn_layer_norm_weight": tensors.get(f"{prefix}.encoder_attn_layer_norm.weight"), + "cross_attn_layer_norm_bias": tensors.get(f"{prefix}.encoder_attn_layer_norm.bias"), + # FFN + "fc1_weight": tensors.get(f"{prefix}.fc1.weight"), + "fc1_bias": tensors.get(f"{prefix}.fc1.bias"), + "fc2_weight": tensors.get(f"{prefix}.fc2.weight"), + "fc2_bias": tensors.get(f"{prefix}.fc2.bias"), + # Final layer norm + "final_layer_norm_weight": tensors.get(f"{prefix}.final_layer_norm.weight"), + "final_layer_norm_bias": tensors.get(f"{prefix}.final_layer_norm.bias"), + } + + def summary(self) -> str: + """Generate a summary of loaded weights.""" + lines = [ + "WhisperWeights Summary:", + f" Config: {self.config.d_model}d, {self.config.encoder_layers}enc, {self.config.decoder_layers}dec", + " Encoder:", + f" - Conv1: {self.encoder_conv1_weight.shape if self.encoder_conv1_weight is not None else 'None'}", + f" - Conv2: {self.encoder_conv2_weight.shape if self.encoder_conv2_weight is not None else 'None'}", + f" - Layers: {len(self.encoder_layers)}", + " Decoder:", + f" - Embed tokens: {self.decoder_embed_tokens.shape if self.decoder_embed_tokens is not None else 'None'}", + f" - Layers: {len(self.decoder_layers)}", + 
] + return "\n".join(lines) + + +def load_whisper_model( + model_path_or_id: str, + cache_dir: Optional[str] = None, +) -> tuple[WhisperConfig, WhisperWeights]: + """Load Whisper model configuration and weights. + + Args: + model_path_or_id: Local path or HuggingFace model ID + cache_dir: Optional cache directory for downloads + + Returns: + Tuple of (WhisperConfig, WhisperWeights) + + Example: + >>> config, weights = load_whisper_model("kotoba-tech/kotoba-whisper-v2.0") + >>> print(config) + >>> print(weights.summary()) + """ + # Check if it's a local path + if os.path.exists(model_path_or_id): + model_path = model_path_or_id + else: + # Download from HuggingFace + model_path = download_model(model_path_or_id, cache_dir) + + # Load config + config = WhisperConfig.from_pretrained(model_path) + + # Load weights + weights = WhisperWeights.from_safetensors(model_path, config) + + return config, weights + + +__all__ = [ + "load_safetensors", + "download_model", + "WhisperWeights", + "load_whisper_model", +] diff --git a/src/pygpukit/asr/whisper/model.py b/src/pygpukit/asr/whisper/model.py new file mode 100644 index 0000000..16573e6 --- /dev/null +++ b/src/pygpukit/asr/whisper/model.py @@ -0,0 +1,481 @@ +"""Whisper model for speech recognition. + +Provides a unified interface for Whisper transcription with support for: +- Single-file transcription +- Streaming/chunked inference for long audio +- Multiple output formats (text, segments with timestamps) +""" + +from __future__ import annotations + +from collections.abc import Iterator +from dataclasses import dataclass, field + +import numpy as np + +from ...core import GPUArray, from_numpy +from ...ops.audio import AudioBuffer +from ..preprocessing import ( + WHISPER_CHUNK_LENGTH, + WHISPER_HOP_LENGTH, + WHISPER_SAMPLE_RATE, + normalize_mel, + pad_or_trim, +) +from .config import WhisperConfig +from .decoder import WhisperDecoder, create_decoder +from .encoder import WhisperEncoder, create_encoder +from .loader import load_whisper_model + + +@dataclass +class TranscriptionSegment: + """A single transcription segment with timing information.""" + + text: str + start: float # seconds + end: float # seconds + tokens: list[int] = field(default_factory=list) + + +@dataclass +class TranscriptionResult: + """Complete transcription result.""" + + text: str + segments: list[TranscriptionSegment] = field(default_factory=list) + language: str | None = None + + +class WhisperTokenizer: + """Simple tokenizer wrapper for Whisper models. + + Uses the HuggingFace tokenizers library if available, + otherwise provides a basic fallback. 
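+
+    Example (illustrative; model_path points at a downloaded model directory):
+        >>> tok = WhisperTokenizer("/path/to/whisper-model")
+        >>> text = tok.decode(token_ids, skip_special_tokens=True)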
+ """ + + def __init__(self, model_path: str): + self.model_path = model_path + self._tokenizer = None + self._load_tokenizer() + + def _load_tokenizer(self) -> None: + """Load tokenizer from model path.""" + import os + + try: + from tokenizers import Tokenizer + + tokenizer_path = os.path.join(self.model_path, "tokenizer.json") + if os.path.exists(tokenizer_path): + self._tokenizer = Tokenizer.from_file(tokenizer_path) + except ImportError: + pass + + def encode(self, text: str) -> list[int]: + """Encode text to token IDs.""" + if self._tokenizer is not None: + return self._tokenizer.encode(text).ids + raise RuntimeError("Tokenizer not available") + + def decode(self, token_ids: list[int], skip_special_tokens: bool = True) -> str: + """Decode token IDs to text.""" + if self._tokenizer is not None: + return self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) + raise RuntimeError("Tokenizer not available") + + +class WhisperModel: + """Whisper model for speech recognition. + + Example: + >>> model = WhisperModel.from_pretrained("kotoba-tech/kotoba-whisper-v2.0") + >>> result = model.transcribe("audio.wav", language="ja") + >>> print(result.text) + + # Streaming mode for long audio + >>> for segment in model.transcribe_streaming(audio_array, language="ja"): + ... print(f"[{segment.start:.2f} - {segment.end:.2f}] {segment.text}") + """ + + def __init__( + self, + config: WhisperConfig, + encoder: WhisperEncoder, + decoder: WhisperDecoder, + tokenizer: WhisperTokenizer | None = None, + ): + self.config = config + self.encoder = encoder + self.decoder = decoder + self.tokenizer = tokenizer + + @classmethod + def from_pretrained( + cls, + model_path_or_id: str, + cache_dir: str | None = None, + ) -> WhisperModel: + """Load a pretrained Whisper model. + + Args: + model_path_or_id: Local path or HuggingFace model ID + cache_dir: Optional cache directory for downloads + + Returns: + Initialized WhisperModel + + Example: + >>> model = WhisperModel.from_pretrained("kotoba-tech/kotoba-whisper-v2.0") + """ + import os + + # Load config and weights + config, weights = load_whisper_model(model_path_or_id, cache_dir) + + # Create encoder and decoder + encoder = create_encoder(config, weights) + decoder = create_decoder(config, weights) + + # Load tokenizer + tokenizer = None + if os.path.exists(model_path_or_id): + tokenizer = WhisperTokenizer(model_path_or_id) + else: + # Try to get cached path + try: + from huggingface_hub import snapshot_download + + model_path = snapshot_download( + repo_id=model_path_or_id, + cache_dir=cache_dir, + allow_patterns=["tokenizer.*"], + ) + tokenizer = WhisperTokenizer(model_path) + except Exception: + pass + + return cls(config, encoder, decoder, tokenizer) + + def transcribe( + self, + audio: np.ndarray | str, + sample_rate: int | None = None, + language: str | None = None, + max_length: int = 448, + temperature: float = 0.0, + **kwargs, + ) -> TranscriptionResult: + """Transcribe audio to text. 
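+
+        Pipeline: (optional) resample to 16 kHz -> log-mel spectrogram ->
+        encoder -> autoregressive decode -> detokenize. Input is padded or
+        trimmed to 30 s; use transcribe_streaming for longer audio.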
+ + Args: + audio: Audio waveform (numpy array) or path to audio file + sample_rate: Sample rate of input audio (required if not 16kHz) + language: Optional language code (e.g., "ja", "en") + max_length: Maximum number of tokens to generate + temperature: Sampling temperature (0 for greedy) + + Returns: + TranscriptionResult with text and optional segments + """ + # Load audio if path + if isinstance(audio, str): + audio = self._load_audio(audio) + + # Resample to 16kHz if needed + if sample_rate is not None and sample_rate != WHISPER_SAMPLE_RATE: + audio_gpu = from_numpy(audio.astype(np.float32)) + audio_buf = AudioBuffer(data=audio_gpu, sample_rate=sample_rate, channels=1) + audio_buf = audio_buf.resample(WHISPER_SAMPLE_RATE) + audio = audio_buf.data.to_numpy() + + # Preprocess to mel spectrogram + mel = self._preprocess_audio(audio) + + # Encode audio + encoder_output = self.encoder(mel) + + # Decode to tokens + tokens = self.decoder.generate( + encoder_output, + max_length=max_length, + temperature=temperature, + top_k=None if temperature == 0.0 else 50, + ) + + # Decode tokens to text + text = self._decode_tokens(tokens) + + return TranscriptionResult( + text=text, + segments=[ + TranscriptionSegment( + text=text, + start=0.0, + end=len(audio) / WHISPER_SAMPLE_RATE, + tokens=tokens, + ) + ], + language=language, + ) + + def transcribe_streaming( + self, + audio: np.ndarray, + language: str | None = None, + chunk_length: float = WHISPER_CHUNK_LENGTH, + overlap: float = 0.0, + max_length: int = 448, + temperature: float = 0.0, + **kwargs, + ) -> Iterator[TranscriptionSegment]: + """Transcribe long audio in chunks, yielding segments as they're processed. + + Args: + audio: Audio waveform at 16kHz + language: Optional language code + chunk_length: Length of each chunk in seconds (default: 30s) + overlap: Overlap between chunks in seconds + max_length: Maximum tokens per chunk + temperature: Sampling temperature + + Yields: + TranscriptionSegment for each processed chunk + """ + samples_per_chunk = int(chunk_length * WHISPER_SAMPLE_RATE) + overlap_samples = int(overlap * WHISPER_SAMPLE_RATE) + stride = samples_per_chunk - overlap_samples + + # Process audio in chunks + start_sample = 0 + while start_sample < len(audio): + end_sample = min(start_sample + samples_per_chunk, len(audio)) + chunk = audio[start_sample:end_sample] + + # Process chunk + mel = self._preprocess_audio(chunk) + encoder_output = self.encoder(mel) + + tokens = self.decoder.generate( + encoder_output, + max_length=max_length, + temperature=temperature, + top_k=None if temperature == 0.0 else 50, + ) + + text = self._decode_tokens(tokens) + + # Calculate timing + start_time = start_sample / WHISPER_SAMPLE_RATE + end_time = end_sample / WHISPER_SAMPLE_RATE + + yield TranscriptionSegment( + text=text, + start=start_time, + end=end_time, + tokens=tokens, + ) + + start_sample += stride + + def _load_audio(self, path: str) -> np.ndarray: + """Load audio file and resample to 16kHz mono. + + Args: + path: Path to audio file + + Returns: + Audio waveform at 16kHz + """ + try: + import soundfile as sf + + audio, sr = sf.read(path) + + # Convert to mono if stereo + if audio.ndim > 1: + audio = audio.mean(axis=1) + + # Resample if needed + if sr != WHISPER_SAMPLE_RATE: + try: + import resampy + + audio = resampy.resample(audio, sr, WHISPER_SAMPLE_RATE) + except ImportError as err: + raise RuntimeError( + f"Audio sample rate is {sr}Hz but Whisper requires {WHISPER_SAMPLE_RATE}Hz. 
" + "Install resampy to enable automatic resampling: pip install resampy" + ) from err + + return audio.astype(np.float32) + + except ImportError as err: + raise ImportError( + "soundfile is required to load audio files. Install with: pip install soundfile" + ) from err + + def _preprocess_audio(self, audio: np.ndarray) -> GPUArray: + """Convert audio to mel spectrogram. + + Args: + audio: Audio waveform at 16kHz + + Returns: + Mel spectrogram [1, n_mels, n_frames] + """ + # Pad or trim to 30 seconds + audio_gpu = pad_or_trim(audio) + audio_np = audio_gpu.to_numpy() + + # Compute mel spectrogram using numpy + mel = self._compute_mel_spectrogram(audio_np) + + # Normalize (accepts numpy directly) + mel = normalize_mel(mel) + + # Add batch dimension + mel_np = mel.to_numpy() + return from_numpy(mel_np.reshape(1, *mel_np.shape)) + + def _compute_mel_spectrogram(self, audio: np.ndarray) -> np.ndarray: + """Compute log-mel spectrogram. + + Args: + audio: Audio waveform at 16kHz + + Returns: + Mel spectrogram [n_mels, n_frames] + """ + from ..preprocessing import WHISPER_N_FFT + + # Use librosa if available, otherwise numpy fallback + try: + import librosa + + mel = librosa.feature.melspectrogram( + y=audio, + sr=WHISPER_SAMPLE_RATE, + n_fft=WHISPER_N_FFT, + hop_length=WHISPER_HOP_LENGTH, + n_mels=self.config.num_mel_bins, + fmin=0, + fmax=8000, + ) + # Convert to log scale + mel = np.log10(np.clip(mel, a_min=1e-10, a_max=None)) + + except ImportError: + # Numpy fallback (basic STFT + mel filterbank) + mel = self._compute_mel_numpy(audio) + + return mel.astype(np.float32) + + def _compute_mel_numpy(self, audio: np.ndarray) -> np.ndarray: + """Compute mel spectrogram using numpy (fallback). + + Args: + audio: Audio waveform + + Returns: + Mel spectrogram + """ + from ..preprocessing import WHISPER_N_FFT + + n_fft = WHISPER_N_FFT + hop_length = WHISPER_HOP_LENGTH + n_mels = self.config.num_mel_bins + + # Pad audio + audio = np.pad(audio, (n_fft // 2, n_fft // 2), mode="reflect") + + # STFT + n_frames = 1 + (len(audio) - n_fft) // hop_length + stft = np.zeros((n_fft // 2 + 1, n_frames), dtype=np.complex64) + + window = np.hanning(n_fft) + for i in range(n_frames): + start = i * hop_length + frame = audio[start : start + n_fft] * window + stft[:, i] = np.fft.rfft(frame) + + # Power spectrum + power = np.abs(stft) ** 2 + + # Mel filterbank + mel_basis = self._create_mel_filterbank(n_mels, n_fft) + mel = mel_basis @ power + + # Log scale + mel = np.log10(np.clip(mel, a_min=1e-10, a_max=None)) + + return mel + + def _create_mel_filterbank(self, n_mels: int, n_fft: int) -> np.ndarray: + """Create mel filterbank matrix. 
+ + Args: + n_mels: Number of mel bands + n_fft: FFT size + + Returns: + Mel filterbank [n_mels, n_fft//2+1] + """ + fmin = 0.0 + fmax = WHISPER_SAMPLE_RATE / 2 + + # Mel scale conversion + def hz_to_mel(hz): + return 2595 * np.log10(1 + hz / 700) + + def mel_to_hz(mel): + return 700 * (10 ** (mel / 2595) - 1) + + # Mel points + mel_min = hz_to_mel(fmin) + mel_max = hz_to_mel(fmax) + mel_points = np.linspace(mel_min, mel_max, n_mels + 2) + hz_points = mel_to_hz(mel_points) + + # FFT bins + bin_points = np.floor((n_fft + 1) * hz_points / WHISPER_SAMPLE_RATE).astype(int) + + # Create filterbank + filterbank = np.zeros((n_mels, n_fft // 2 + 1)) + for i in range(n_mels): + left = bin_points[i] + center = bin_points[i + 1] + right = bin_points[i + 2] + + # Rising edge + for j in range(left, center): + filterbank[i, j] = (j - left) / (center - left) + + # Falling edge + for j in range(center, right): + filterbank[i, j] = (right - j) / (right - center) + + return filterbank + + def _decode_tokens(self, tokens: list[int]) -> str: + """Decode token IDs to text. + + Args: + tokens: List of token IDs + + Returns: + Decoded text string + """ + if self.tokenizer is not None: + return self.tokenizer.decode(tokens, skip_special_tokens=True) + + # Fallback: just return token IDs as string + return f"" + + +__all__ = [ + "WhisperModel", + "WhisperTokenizer", + "TranscriptionResult", + "TranscriptionSegment", +] diff --git a/src/pygpukit/core/array.py b/src/pygpukit/core/array.py index b2c8b40..6fbfa8f 100644 --- a/src/pygpukit/core/array.py +++ b/src/pygpukit/core/array.py @@ -67,9 +67,11 @@ def _wrap_native(cls, native_array: Any) -> GPUArray: float16, float32, float64, + int8, int16, int32, int64, + uint8, ) native = get_native_module() @@ -90,6 +92,10 @@ def _wrap_native(cls, native_array: Any) -> GPUArray: dtype = int32 elif native_dtype == native.DataType.Int16: dtype = int16 + elif native_dtype == native.DataType.Int8: + dtype = int8 + elif native_dtype == native.DataType.UInt8: + dtype = uint8 else: raise ValueError(f"Unknown native dtype: {native_dtype}") @@ -247,30 +253,94 @@ def __del__(self) -> None: # Arithmetic operators # ======================================================================== - def __add__(self, other: GPUArray) -> GPUArray: - """Element-wise addition.""" + def __add__(self, other: GPUArray | int | float) -> GPUArray: + """Element-wise addition. + + Supports both GPUArray and scalar (int/float) operands. + Broadcasting is supported for compatible shapes. + """ + if isinstance(other, (int, float)): + return self._scalar_op(other, lambda a, b: a + b) + + # Check if broadcasting is needed + if self.shape != other.shape: + # Use numpy broadcasting + from pygpukit.core.factory import from_numpy + + a_np = self.to_numpy() + b_np = other.to_numpy() + result = a_np + b_np + return from_numpy(result.astype(a_np.dtype)) + from pygpukit.ops.basic import add return add(self, other) - def __sub__(self, other: GPUArray) -> GPUArray: - """Element-wise subtraction.""" + def __radd__(self, other: int | float) -> GPUArray: + """Right-hand addition for scalar + GPUArray.""" + return self._scalar_op(other, lambda a, b: b + a) + + def __sub__(self, other: GPUArray | int | float) -> GPUArray: + """Element-wise subtraction. + + Supports both GPUArray and scalar (int/float) operands. 
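+        Scalar operands are evaluated through a NumPy round-trip (see _scalar_op).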
+ """ + if isinstance(other, (int, float)): + return self._scalar_op(other, lambda a, b: a - b) from pygpukit.ops.basic import sub return sub(self, other) - def __mul__(self, other: GPUArray) -> GPUArray: - """Element-wise multiplication.""" + def __rsub__(self, other: int | float) -> GPUArray: + """Right-hand subtraction for scalar - GPUArray.""" + return self._scalar_op(other, lambda a, b: b - a) + + def __mul__(self, other: GPUArray | int | float) -> GPUArray: + """Element-wise multiplication. + + Supports both GPUArray and scalar (int/float) operands. + """ + if isinstance(other, (int, float)): + return self._scalar_op(other, lambda a, b: a * b) from pygpukit.ops.basic import mul return mul(self, other) - def __truediv__(self, other: GPUArray) -> GPUArray: - """Element-wise division.""" + def __rmul__(self, other: int | float) -> GPUArray: + """Right-hand multiplication for scalar * GPUArray.""" + return self._scalar_op(other, lambda a, b: b * a) + + def __truediv__(self, other: GPUArray | int | float) -> GPUArray: + """Element-wise division. + + Supports both GPUArray and scalar (int/float) operands. + """ + if isinstance(other, (int, float)): + return self._scalar_op(other, lambda a, b: a / b) from pygpukit.ops.basic import div return div(self, other) + def __rtruediv__(self, other: int | float) -> GPUArray: + """Right-hand division for scalar / GPUArray.""" + return self._scalar_op(other, lambda a, b: b / a) + + def _scalar_op(self, scalar: int | float, op) -> GPUArray: + """Apply a scalar operation using NumPy. + + Args: + scalar: The scalar operand. + op: A callable that takes (array, scalar) and returns the result. + + Returns: + A new GPUArray with the result. + """ + from pygpukit.core.factory import from_numpy + + np_data = self.to_numpy() + result = op(np_data, scalar) + return from_numpy(result.astype(np_data.dtype)) + def __matmul__(self, other: GPUArray) -> GPUArray: """Matrix multiplication.""" from pygpukit.ops.basic import matmul @@ -377,8 +447,10 @@ def narrow(self, offset: int, length: int) -> GPUArray: # Call native narrow view_native = native.GPUArray.narrow(src_native, offset_elements, new_shape) - # Wrap the view - return GPUArray._wrap_native(view_native) + # Wrap the view and keep reference to source to prevent memory from being freed + view_arr = GPUArray._wrap_native(view_native) + view_arr._source_ref = self + return view_arr def view(self, new_shape: tuple[int, ...]) -> GPUArray: """Create a zero-copy view with a different shape (same total elements). @@ -423,8 +495,10 @@ def view(self, new_shape: tuple[int, ...]) -> GPUArray: # Use narrow with offset=0 to create view with new shape view_native = native.GPUArray.narrow(src_native, 0, list(new_shape)) - # Wrap the view - return GPUArray._wrap_native(view_native) + # Wrap the view and keep reference to source to prevent memory from being freed + view_arr = GPUArray._wrap_native(view_native) + view_arr._source_ref = self # Keep source alive while view exists + return view_arr def slice_rows(self, num_rows: int) -> GPUArray: """Create a zero-copy view of the first N rows (batch dimension). 
@@ -468,4 +542,200 @@ def slice_rows(self, num_rows: int) -> GPUArray: # Use narrow with offset=0 to get first num_rows rows view_native = native.GPUArray.narrow(src_native, 0, new_shape) - return GPUArray._wrap_native(view_native) + # Keep reference to source to prevent memory from being freed + view_arr = GPUArray._wrap_native(view_native) + view_arr._source_ref = self + return view_arr + + def transpose(self, *axes: int) -> GPUArray: + """Transpose the array by permuting its axes. + + Uses native GPU kernels when available for common patterns: + - 2D (1,0): Native matmul.transpose() + - 3D (1,0,2): Native tensor.transpose_3d_021() + - 3D (0,2,1): Native tensor.transpose_3d_012() + - 4D (0,2,1,3): Native tensor.transpose_4d_0213() + - 4D (0,1,3,2): Native tensor.transpose_4d_0132() + - Other patterns: CPU fallback + + Args: + *axes: The new order of axes. If not provided, reverses all axes. + For a 3D array, transpose(0, 2, 1) swaps the last two axes. + + Returns: + A new GPUArray with transposed data. + + Example: + # Transpose 2D matrix + a = from_numpy(np.array([[1, 2], [3, 4]])) + b = a.transpose() # or a.T + + # Permute 3D tensor axes + x = from_numpy(np.zeros((2, 3, 4))) + y = x.transpose(0, 2, 1) # shape (2, 4, 3) + """ + from pygpukit.core.backend import NativeBackend, get_backend + from pygpukit.core.factory import from_numpy + + # Normalize axes + if len(axes) == 0: + # Reverse all axes + axes = tuple(range(self.ndim - 1, -1, -1)) + + # Check if we can use native implementations + backend = get_backend() + dtype_str = str(self.dtype) + use_native = ( + isinstance(backend, NativeBackend) + and backend.is_available() + and dtype_str in ("float32", "float16", "bfloat16") + ) + + if use_native: + # 2D transpose: (1, 0) + if self.ndim == 2 and axes == (1, 0): + from pygpukit.ops.matmul import transpose as matmul_transpose + + return matmul_transpose(self) + + # 3D transpose (1, 0, 2): [d0, d1, d2] -> [d1, d0, d2] + if self.ndim == 3 and axes == (1, 0, 2): + from pygpukit.ops.tensor import transpose_3d_021 + + result = transpose_3d_021(self) + return result if result is not None else self + + # 3D transpose (0, 2, 1): [d0, d1, d2] -> [d0, d2, d1] + if self.ndim == 3 and axes == (0, 2, 1): + from pygpukit.ops.tensor import transpose_3d_012 + + result = transpose_3d_012(self) + return result if result is not None else self + + # 4D transpose (0, 2, 1, 3): [d0, d1, d2, d3] -> [d0, d2, d1, d3] + if self.ndim == 4 and axes == (0, 2, 1, 3): + from pygpukit.ops.tensor import transpose_4d_0213 + + result = transpose_4d_0213(self) + return result if result is not None else self + + # 4D transpose (0, 1, 3, 2): [d0, d1, d2, d3] -> [d0, d1, d3, d2] + if self.ndim == 4 and axes == (0, 1, 3, 2): + from pygpukit.ops.tensor import transpose_4d_0132 + + result = transpose_4d_0132(self) + return result if result is not None else self + + # CPU fallback for unsupported patterns + np_data = self.to_numpy() + result = np_data.transpose(*axes) + return from_numpy(result.copy()) + + @property + def T(self) -> GPUArray: + """Return transposed array (reverses all axes).""" + return self.transpose() + + def reshape(self, *shape: int) -> GPUArray: + """Reshape the array to a new shape. + + Args: + *shape: The new shape. Can be passed as separate args or as a tuple. + One dimension can be -1 to infer from the total size. + + Returns: + A new GPUArray with the specified shape. 
+ + Example: + x = from_numpy(np.zeros((2, 3, 4))) + y = x.reshape(6, 4) # or x.reshape((6, 4)) + z = x.reshape(-1, 4) # infer first dimension + """ + from pygpukit.core.backend import NativeBackend, get_backend + + # Handle both reshape(2, 3) and reshape((2, 3)) + if len(shape) == 1 and isinstance(shape[0], (tuple, list)): + shape = tuple(shape[0]) + + # Handle -1 dimension inference + shape = list(shape) + total_size = 1 + for dim in self.shape: + total_size *= dim + + neg_idx = -1 + known_size = 1 + for i, dim in enumerate(shape): + if dim == -1: + if neg_idx >= 0: + raise ValueError("reshape: only one dimension can be -1") + neg_idx = i + else: + known_size *= dim + + if neg_idx >= 0: + if total_size % known_size != 0: + raise ValueError( + f"reshape: cannot infer dimension, total size {total_size} " + f"not divisible by {known_size}" + ) + shape[neg_idx] = total_size // known_size + + shape = tuple(shape) + + # Verify total size + output_size = 1 + for dim in shape: + output_size *= dim + if output_size != total_size: + raise ValueError( + f"reshape: cannot reshape array of size {total_size} into shape {shape}" + ) + + # Use native reshape_copy if available (keeps data on GPU) + backend = get_backend() + if isinstance(backend, NativeBackend) and backend.is_available(): + dtype_str = str(self.dtype) + if dtype_str in ("float32", "float16", "bfloat16"): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + input_native = self._get_native() + c_native = native.reshape_copy(input_native, list(shape)) + return GPUArray._wrap_native(c_native) + + # CPU fallback + from pygpukit.core.factory import from_numpy + + np_data = self.to_numpy() + result = np_data.reshape(shape) + return from_numpy(result.copy()) + + def __getitem__(self, key) -> GPUArray: + """Index or slice the array. + + Supports NumPy-style indexing including: + - Integer indexing: arr[0] + - Slicing: arr[:10], arr[1:5], arr[::2] + - Multi-dimensional: arr[0, :, 1:3] + + Args: + key: Index, slice, or tuple of indices/slices. + + Returns: + A new GPUArray containing the selected elements. + + Example: + x = from_numpy(np.arange(100).reshape(10, 10)) + row = x[0] # First row + col = x[:, 0] # First column + sub = x[:5, :5] # 5x5 subarray + """ + from pygpukit.core.factory import from_numpy + + np_data = self.to_numpy() + result = np_data[key] + # Handle scalar result + if not isinstance(result, np.ndarray): + result = np.array(result) + return from_numpy(result.copy()) diff --git a/src/pygpukit/llm/layers.py b/src/pygpukit/llm/layers.py index be750e6..da9b82d 100644 --- a/src/pygpukit/llm/layers.py +++ b/src/pygpukit/llm/layers.py @@ -26,6 +26,7 @@ concat_axis0, copy_to, gelu, + gemv_bf16, kv_cache_prefill_gqa, kv_cache_update_gqa, layernorm, @@ -58,8 +59,14 @@ class Linear: """Linear layer: y = xW^T + b Weights are stored as [out_features, in_features] (PyTorch convention). + + For M=1 (single token decode), uses custom GEMV kernel which is 4-6x faster + than cuBLASLt matmul. Automatically falls back to matmul for batch > 1. 
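+
+    Example (illustrative):
+        >>> layer = Linear(weight, bias)  # weight: [out_features, in_features]
+        >>> y = layer(x)                  # [M, in_features] -> [M, out_features]
+
+    Set Linear._use_gemv = False to force the plain matmul path for all layers.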
""" + # Class-level flag to enable/disable GEMV optimization + _use_gemv: bool = True + def __init__(self, weight: GPUArray, bias: GPUArray | None = None): if weight.ndim != 2: raise ValueError(f"weight must be 2D, got {weight.ndim}D") @@ -85,7 +92,29 @@ def __call__(self, x: GPUArray, *, out: GPUArray | None = None) -> GPUArray: if self._weight_t is None: self._weight_t = transpose(self.weight) - y = matmul(x, self._weight_t, out=out) + # Use GEMV for M=1 with BF16 (1.3-2.4x faster than matmul) + # Skip GEMV when out is provided (CUDA Graph mode) - GEMV allocates internally + use_gemv = ( + Linear._use_gemv + and x.shape[0] == 1 + and x.dtype == dt_bfloat16 + and out is None # GEMV allocates, not compatible with CUDA Graph + ) + + if use_gemv: + # GEMV path: zero-copy view to 1D, call gemv_bf16, view back to 2D + x_1d = x.view((self.in_features,)) + y_1d = gemv_bf16(x_1d, self._weight_t) + + if out is not None: + # Copy to output buffer + copy_to(y_1d.view((1, self.out_features)), out) + y = out + else: + y = y_1d.view((1, self.out_features)) + else: + # Standard matmul path + y = matmul(x, self._weight_t, out=out) if self.bias is not None: bias_add_inplace(y, self.bias) diff --git a/src/pygpukit/ops/__init__.py b/src/pygpukit/ops/__init__.py index c7f29c1..7e22fae 100644 --- a/src/pygpukit/ops/__init__.py +++ b/src/pygpukit/ops/__init__.py @@ -16,6 +16,8 @@ # Elementwise add, add_inplace, + # Matmul + batched_matmul, # Neural Network bias_add_inplace, # Tensor @@ -32,7 +34,17 @@ embedding_lookup_ptr, # Unary exp, + fp8_available, + fp8_fp8_get_scale_sizes, + fp8_fp8_sm120_available, + fp8_sm90_available, + fp8_sm100_available, + fp8_sm120_available, gelu, + # GEMV + gemv_bf16, + gemv_nvf4_available, + gemv_nvf4_bf16, kv_cache_prefill, kv_cache_prefill_gqa, kv_cache_update, @@ -43,11 +55,21 @@ linear_bias_gelu, log, matmul, + matmul_fp8, + matmul_fp8_fp8_blockwise_sm120, + matmul_fp8_fp8_sm120, + matmul_fp8_sm90, + matmul_fp8_sm100, + matmul_fp8_sm120, + matmul_nvf4_bf16_sm120, # Reduction max, mean, mul, mul_inplace, + nvf4_bf16_sm120_available, + nvf4_get_sizes, + quantize_bf16_to_nvf4, relu, repeat_interleave_axis1, reshape_copy, @@ -73,6 +95,7 @@ sum, transpose, transpose_3d_021, + transpose_4d_0213, ) __all__ = [ @@ -95,8 +118,29 @@ "softmax", # Matmul "matmul", + "batched_matmul", "transpose", "linear_bias_gelu", + "matmul_fp8", + "matmul_fp8_fp8_blockwise_sm120", + "matmul_fp8_fp8_sm120", + "matmul_fp8_sm90", + "matmul_fp8_sm100", + "matmul_fp8_sm120", + "matmul_nvf4_bf16_sm120", + "fp8_available", + "fp8_fp8_get_scale_sizes", + "fp8_fp8_sm120_available", + "fp8_sm90_available", + "fp8_sm100_available", + "fp8_sm120_available", + "nvf4_bf16_sm120_available", + # GEMV + "gemv_bf16", + "gemv_nvf4_bf16", + "gemv_nvf4_available", + "nvf4_get_sizes", + "quantize_bf16_to_nvf4", # Neural Network "gelu", "silu", @@ -131,6 +175,7 @@ "concat_axis0", "repeat_interleave_axis1", "transpose_3d_021", + "transpose_4d_0213", "reshape_copy", "cast_f32_to_bf16", "cast_f32_to_f16", diff --git a/src/pygpukit/ops/basic.py b/src/pygpukit/ops/basic.py index 07d6b1a..395070b 100644 --- a/src/pygpukit/ops/basic.py +++ b/src/pygpukit/ops/basic.py @@ -25,11 +25,13 @@ from pygpukit.ops.elementwise import ( add, add_inplace, + clamp, copy_to, div, mul, mul_inplace, sub, + where, ) # Re-export embedding operations @@ -46,8 +48,29 @@ # Re-export matmul operations from pygpukit.ops.matmul import ( + batched_matmul, + fp8_available, + fp8_fp8_get_scale_sizes, + fp8_fp8_sm120_available, + fp8_sm90_available, + 
fp8_sm100_available, + fp8_sm120_available, + # GEMV operations + gemv_bf16, + gemv_nvf4_available, + gemv_nvf4_bf16, linear_bias_gelu, matmul, + matmul_fp8, + matmul_fp8_fp8_blockwise_sm120, + matmul_fp8_fp8_sm120, + matmul_fp8_sm90, + matmul_fp8_sm100, + matmul_fp8_sm120, + matmul_nvf4_bf16_sm120, + nvf4_bf16_sm120_available, + nvf4_get_sizes, + quantize_bf16_to_nvf4, transpose, ) @@ -62,17 +85,22 @@ sdpa_causal, sdpa_causal_fixed_cache, sdpa_causal_fixed_cache_ptr, + sigmoid, silu, slice_rows_range_ptr, split_qkv_batch, + tanh, ) # Re-export reduction operations from pygpukit.ops.reduction import ( + argmax, max, mean, + min, softmax, sum, + sum_axis, ) # Re-export sampling operations @@ -96,13 +124,20 @@ repeat_interleave_axis1, reshape_copy, transpose_3d_021, + transpose_4d_0213, ) # Re-export unary operations from pygpukit.ops.unary import ( + abs, + cos, exp, log, + neg, relu, + rsqrt, + sin, + sqrt, ) __all__ = [ @@ -118,22 +153,56 @@ "add_inplace", "mul_inplace", "copy_to", + "clamp", + "where", # Unary + "abs", + "cos", "exp", "log", + "neg", "relu", + "rsqrt", + "sin", + "sqrt", # Reduction - "sum", - "mean", + "argmax", "max", + "mean", + "min", "softmax", + "sum", + "sum_axis", # Matmul "matmul", + "batched_matmul", "transpose", "linear_bias_gelu", + "matmul_fp8", + "matmul_fp8_sm90", + "matmul_fp8_sm100", + "matmul_fp8_sm120", + "matmul_nvf4_bf16_sm120", + "fp8_available", + "fp8_fp8_sm120_available", + "fp8_fp8_get_scale_sizes", + "fp8_sm90_available", + "fp8_sm100_available", + "fp8_sm120_available", + "matmul_fp8_fp8_blockwise_sm120", + "matmul_fp8_fp8_sm120", + "nvf4_bf16_sm120_available", + # GEMV + "gemv_bf16", + "gemv_nvf4_bf16", + "gemv_nvf4_available", + "nvf4_get_sizes", + "quantize_bf16_to_nvf4", # Neural Network "gelu", + "sigmoid", "silu", + "tanh", "layernorm", "rmsnorm", "bias_add_inplace", @@ -165,6 +234,7 @@ "concat_axis0", "repeat_interleave_axis1", "transpose_3d_021", + "transpose_4d_0213", "reshape_copy", "cast_f32_to_bf16", "cast_f32_to_f16", diff --git a/src/pygpukit/ops/elementwise.py b/src/pygpukit/ops/elementwise.py index ac38b7b..255afa0 100644 --- a/src/pygpukit/ops/elementwise.py +++ b/src/pygpukit/ops/elementwise.py @@ -241,3 +241,60 @@ def copy_to(src: GPUArray, dst: GPUArray) -> None: src_native = src._get_native() dst_native = dst._get_native() native.copy_to(src_native, dst_native) + + +def clamp(a: GPUArray, min_val: float, max_val: float) -> GPUArray: + """Element-wise clamp: clamp(x, min, max). + + Args: + a: Input array (float types). + min_val: Minimum value. + max_val: Maximum value. + + Returns: + A new GPUArray with values clamped to [min_val, max_val]. + """ + import numpy as np + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return GPUArray._wrap_native(native.clamp(a._get_native(), min_val, max_val)) + else: + a_np = a.to_numpy() + return from_numpy(np.clip(a_np, min_val, max_val)) + + +def where(cond: GPUArray, a: GPUArray, b: GPUArray) -> GPUArray: + """Conditional select: where(cond, a, b) = cond ? a : b. + + Args: + cond: Boolean condition array (uint8 or int8, 0=False, nonzero=True). + a: Values to use where condition is True. + b: Values to use where condition is False. + + Returns: + A new GPUArray with values selected from a or b based on cond. 
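+
+    Example (illustrative; x_np, a, b are placeholder arrays):
+        >>> mask = from_numpy((x_np > 0).astype(np.uint8))
+        >>> y = where(mask, a, b)  # elements of a where x_np > 0, else b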
+ """ + import numpy as np + + _validate_same_shape(a, b, "where") + _validate_same_dtype(a, b, "where") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return GPUArray._wrap_native( + native.where(cond._get_native(), a._get_native(), b._get_native()) + ) + else: + cond_np: np.ndarray = cond.to_numpy().astype(bool) + a_np = a.to_numpy() + b_np = b.to_numpy() + return from_numpy(np.where(cond_np, a_np, b_np)) diff --git a/src/pygpukit/ops/matmul.py b/src/pygpukit/ops/matmul.py index 9e235cb..c15a523 100644 --- a/src/pygpukit/ops/matmul.py +++ b/src/pygpukit/ops/matmul.py @@ -5,6 +5,8 @@ from __future__ import annotations +import warnings + import numpy as np from pygpukit.core.array import GPUArray @@ -281,3 +283,1266 @@ def _linear_bias_gelu_native( bias_native = bias._get_native() c_native = native.linear_bias_gelu(input_native, weight_native, bias_native) return GPUArray._wrap_native(c_native) + + +def batched_matmul( + a: GPUArray, + b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """Batched matrix multiplication for 3D and 4D tensors. + + Supports: + - 3D: [batch, M, K] @ [batch, K, N] -> [batch, M, N] + - 4D: [batch1, batch2, M, K] @ [batch1, batch2, K, N] -> [batch1, batch2, M, N] + + Args: + a: First input array (3D or 4D). + b: Second input array (3D or 4D). + out: Optional output array. If provided, result is written in-place. + + Returns: + The result GPUArray with shape [..., M, N]. + + Raises: + ValueError: If arrays are not 3D/4D or dimensions don't match. + """ + if a.ndim not in (3, 4): + raise ValueError(f"batched_matmul requires 3D or 4D arrays, got {a.ndim}D") + if b.ndim not in (3, 4): + raise ValueError(f"batched_matmul requires 3D or 4D arrays, got {b.ndim}D") + if a.ndim != b.ndim: + raise ValueError(f"batched_matmul requires same ndim, got {a.ndim}D and {b.ndim}D") + + _validate_same_dtype(a, b, "batched_matmul") + + # Extract dimensions + if a.ndim == 3: + batch = a.shape[0] + M, K = a.shape[1], a.shape[2] + K2, N = b.shape[1], b.shape[2] + if b.shape[0] != batch: + raise ValueError(f"Batch dimension mismatch: {a.shape[0]} vs {b.shape[0]}") + if K != K2: + raise ValueError(f"Inner dimension mismatch: {K} vs {K2}") + out_shape = (batch, M, N) + batch_count = batch + else: # 4D + batch1, batch2 = a.shape[0], a.shape[1] + M, K = a.shape[2], a.shape[3] + K2, N = b.shape[2], b.shape[3] + if b.shape[0] != batch1 or b.shape[1] != batch2: + raise ValueError( + f"Batch dimensions mismatch: ({batch1}, {batch2}) vs ({b.shape[0]}, {b.shape[1]})" + ) + if K != K2: + raise ValueError(f"Inner dimension mismatch: {K} vs {K2}") + out_shape = (batch1, batch2, M, N) + batch_count = batch1 * batch2 + + # Validate output + if out is not None: + if out.shape != out_shape: + raise ValueError(f"out shape {out.shape} does not match expected {out_shape}") + if out.dtype != a.dtype: + raise ValueError(f"out dtype {out.dtype} does not match input dtype {a.dtype}") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _batched_matmul_native(a, b, M, N, K, batch_count, out_shape, out=out) + else: + return _batched_matmul_cpu(a, b, out=out) + + +def _batched_matmul_cpu(a: GPUArray, b: GPUArray, *, out: GPUArray | None = None) -> GPUArray: + """CPU implementation of batched_matmul.""" + a_np = a.to_numpy() + b_np = b.to_numpy() + result_np = np.matmul(a_np, b_np) + result = 
from_numpy(result_np) + + if out is not None: + # Copy result to output buffer + from ..ops.elementwise import copy_to + + copy_to(result, out) + return out + else: + return result + + +def _batched_matmul_loop( + a: GPUArray, b: GPUArray, out_shape: tuple[int, ...], *, out: GPUArray | None = None +) -> GPUArray: + """GPU batched matmul using loop over individual matmuls. + + This is a fallback for when CUTLASS strided batched GEMM is not available + (e.g., SM 120). Uses native matmul kernel for each batch element. + """ + from pygpukit.core.backend import get_native_module + + native = get_native_module() + + # Reshape to 3D for easier iteration: [batch, M, K] @ [batch, K, N] + if a.ndim == 4: + batch1, batch2 = a.shape[0], a.shape[1] + M, K = a.shape[2], a.shape[3] + N = b.shape[3] + total_batch = batch1 * batch2 + + a_3d = a.reshape(total_batch, M, K) + b_3d = b.reshape(total_batch, K, N) + else: + total_batch = a.shape[0] + M, K = a.shape[1], a.shape[2] + N = b.shape[2] + + a_3d = a + b_3d = b + + # Allocate output + if out is None: + out_native = native.empty(list(out_shape), native.DataType.Float32) + out = GPUArray._wrap_native(out_native) + + # Perform batched matmul via loop + for i in range(total_batch): + # Extract slice (creates view/copy depending on implementation) + a_i = a_3d.to_numpy()[i] + b_i = b_3d.to_numpy()[i] + + a_gpu = from_numpy(a_i) + b_gpu = from_numpy(b_i) + + # Compute matmul for this batch element + c_gpu = matmul(a_gpu, b_gpu) + + # Copy result to output + out_np = out.to_numpy() + if a.ndim == 4: + i1, i2 = i // batch2, i % batch2 + out_np[i1, i2] = c_gpu.to_numpy() + else: + out_np[i] = c_gpu.to_numpy() + out = from_numpy(out_np) + + return out + + +def _batched_matmul_native( + a: GPUArray, + b: GPUArray, + M: int, + N: int, + K: int, + batch_count: int, + out_shape: tuple[int, ...], + *, + out: GPUArray | None = None, +) -> GPUArray: + """Native cuBLASLt strided batched GEMM implementation.""" + from pygpukit.core.backend import get_native_module + from pygpukit.core.dtypes import float32 + + native = get_native_module() + + # Currently only FP32 supported via cuBLASLt strided batched + if a.dtype != float32: + warnings.warn( + f"batched_matmul: GPU kernel requires float32, got {a.dtype}. Using CPU fallback (slow)", + RuntimeWarning, + stacklevel=3, + ) + return _batched_matmul_cpu(a, b, out=out) + + # Compute strides for strided batched GEMM + strideA = M * K + strideB = K * N + strideC = M * N + + # Get native arrays + a_native = a._get_native() + b_native = b._get_native() + + # Allocate output if needed (using native allocation) + if out is None: + out_native = native.empty(list(out_shape), native.DataType.Float32) + out = GPUArray._wrap_native(out_native) + else: + out_native = out._get_native() + + # Call strided batched GEMM with CPU fallback for unsupported architectures + try: + native.gemm_strided_batched_fp32( + a_native, + b_native, + out_native, + M, + N, + K, + batch_count, + strideA, + strideB, + strideC, + ) + except RuntimeError: + # CUTLASS not available/failed (e.g., SM 120) - fall back to CPU + warnings.warn( + "batched_matmul: CUTLASS kernel failed, using CPU fallback (slow)", + RuntimeWarning, + stacklevel=3, + ) + return _batched_matmul_cpu(a, b, out=out) + + return out + + +def fp8_available() -> bool: + """Check if FP8 GEMM is available (any backend). + + Returns: + True if FP8 GEMM is available (requires SM90+ GPU). 
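+
+    Example (illustrative):
+        >>> if fp8_available():
+        ...     print("FP8 GEMM kernels are usable on this device")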
+ """ + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return native.fp8_available() + else: + return False + + +def fp8_sm90_available() -> bool: + """Check if FP8 GEMM is available on SM90 (Hopper). + + Returns: + True if FP8 GEMM is available (requires SM90+ and CUTLASS SM90 support). + """ + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return native.fp8_sm90_available() + else: + return False + + +def fp8_sm100_available() -> bool: + """Check if FP8 GEMM is available on SM100 (Blackwell datacenter). + + This may work on SM120 (Blackwell GeForce) as a fallback since both + are Blackwell architecture. + + Returns: + True if FP8 GEMM is available (requires SM100+ and CUTLASS SM100 support). + """ + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return native.fp8_sm100_available() + else: + return False + + +def fp8_sm120_available() -> bool: + """Check if FP8 GEMM is available on SM120 (Blackwell GeForce). + + Note: Currently disabled due to CUTLASS bug #2902. + + Returns: + True if FP8 GEMM is available (requires SM120+ and CUTLASS SM120 support). + """ + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return native.fp8_sm120_available() + else: + return False + + +def fp8_fp8_sm120_available() -> bool: + """Check if Pure FP8 I/O GEMM is available on SM120 (Blackwell GeForce). + + This is for FP8 models where weights and activations are already in FP8 format. + + Returns: + True if Pure FP8 GEMM is available (requires SM120+ and CUTLASS SM120 support). + """ + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return native.fp8_fp8_sm120_available() + else: + return False + + +def matmul_fp8_fp8_sm120( + a: GPUArray, + b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """Pure FP8 I/O matrix multiplication for SM120 (Blackwell GeForce). + + This function takes FP8 E4M3 inputs directly (no conversion from FP32), + performs the GEMM using CUTLASS FP8 kernels, and returns FP8 E4M3 output. + + This is optimized for FP8 models (Llama 3.1 FP8, etc.) where weights + and activations are already quantized to FP8. + + Args: + a: First input array (M x K), FP8 E4M3 stored as uint8. + b: Second input array (K x N), FP8 E4M3 stored as uint8. + Should be in ColumnMajor format (pre-transposed). + out: Optional output array (M x N), uint8. If provided, result is + written to this array instead of allocating a new one. + + Returns: + The result GPUArray (M x N), FP8 E4M3 stored as uint8. + + Raises: + ValueError: If arrays are not 2D, dtypes are not uint8, or dimensions don't match. + RuntimeError: If FP8 SM120 is not available. 
+ + Example: + >>> import pygpukit as gk + >>> # Assuming A and B are already FP8 quantized (stored as uint8) + >>> A = gk.from_numpy(fp8_a_data) # [M, K] uint8 + >>> B = gk.from_numpy(fp8_b_data) # [K, N] uint8 (ColumnMajor) + >>> C = gk.ops.matmul_fp8_fp8_sm120(A, B) # [M, N] uint8 + """ + from pygpukit.core.dtypes import uint8 + + if a.ndim != 2: + raise ValueError( + f"matmul_fp8_fp8_sm120 requires 2D arrays, got {a.ndim}D for first argument" + ) + if b.ndim != 2: + raise ValueError( + f"matmul_fp8_fp8_sm120 requires 2D arrays, got {b.ndim}D for second argument" + ) + + if a.shape[1] != b.shape[0]: + raise ValueError( + f"matmul_fp8_fp8_sm120 dimension mismatch: {a.shape} @ {b.shape} " + f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)" + ) + + if a.dtype != uint8 or b.dtype != uint8: + raise ValueError("matmul_fp8_fp8_sm120 requires uint8 inputs (FP8 E4M3)") + + if not fp8_fp8_sm120_available(): + raise RuntimeError("Pure FP8 SM120 GEMM is not available. Requires SM120+ GPU.") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _matmul_fp8_fp8_sm120_native(a, b, out=out) + else: + raise RuntimeError("Pure FP8 SM120 GEMM requires native backend") + + +def _matmul_fp8_fp8_sm120_native( + a: GPUArray, + b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """Native C++ implementation of Pure FP8 I/O GEMM for SM120.""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + + # Get native arrays + a_native = a._get_native() + b_native = b._get_native() + + # Allocate output if needed + if out is None: + M, K = a.shape + N = b.shape[1] + out_native = native.empty([M, N], native.DataType.UInt8) + out = GPUArray._wrap_native(out_native) + else: + out_native = out._get_native() + + # Call Pure FP8 GEMM + native.gemm_fp8_fp8_sm120(a_native, b_native, out_native) + + return out + + +def fp8_fp8_get_scale_sizes(M: int, N: int, K: int) -> tuple[int, int]: + """Get scale factor sizes for FP8 blockwise GEMM. + + Returns the required sizes for scale_A and scale_B arrays for the + given problem dimensions. These sizes depend on the internal tile + configuration of the CUTLASS kernel. + + Args: + M: Number of rows in A and output. + N: Number of columns in B and output. + K: Inner dimension (columns of A, rows of B). + + Returns: + Tuple of (scale_A_size, scale_B_size) as integers. + + Example: + >>> sfa_size, sfb_size = fp8_fp8_get_scale_sizes(256, 256, 256) + >>> scale_A = pk.from_numpy(np.ones(sfa_size, dtype=np.float32)) + >>> scale_B = pk.from_numpy(np.ones(sfb_size, dtype=np.float32)) + """ + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return native.fp8_fp8_get_scale_sizes(M, N, K) + else: + return (0, 0) + + +def matmul_fp8_fp8_blockwise_sm120( + a: GPUArray, + b: GPUArray, + scale_a: GPUArray, + scale_b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """Blockwise scaled FP8 I/O matrix multiplication for SM120. + + This function takes FP8 E4M3 inputs with per-block scale factors, + performs the GEMM using CUTLASS FP8 kernels, and returns FP8 E4M3 output. + + The scale factors are applied per block during the GEMM computation, + enabling better precision for FP8 models with varied value ranges. + + Args: + a: First input array (M x K), FP8 E4M3 stored as uint8. + b: Second input array (K x N), FP8 E4M3 stored as uint8. 
+ Should be in ColumnMajor format (pre-transposed). + scale_a: Scale factors for A, float32. Size from fp8_fp8_get_scale_sizes(). + scale_b: Scale factors for B, float32. Size from fp8_fp8_get_scale_sizes(). + out: Optional output array (M x N), uint8. If provided, result is + written to this array instead of allocating a new one. + + Returns: + The result GPUArray (M x N), FP8 E4M3 stored as uint8. + + Raises: + ValueError: If arrays are not 2D, dtypes are wrong, or dimensions don't match. + RuntimeError: If FP8 SM120 is not available. + + Example: + >>> import pygpukit as gk + >>> from pygpukit.ops import fp8_fp8_get_scale_sizes, matmul_fp8_fp8_blockwise_sm120 + >>> M, N, K = 256, 256, 256 + >>> sfa_size, sfb_size = fp8_fp8_get_scale_sizes(M, N, K) + >>> scale_A = gk.from_numpy(np.ones(sfa_size, dtype=np.float32)) + >>> scale_B = gk.from_numpy(np.ones(sfb_size, dtype=np.float32)) + >>> C = matmul_fp8_fp8_blockwise_sm120(A_fp8, B_fp8, scale_A, scale_B) + """ + from pygpukit.core.dtypes import float32, uint8 + + if a.ndim != 2: + raise ValueError(f"matmul_fp8_fp8_blockwise_sm120 requires 2D arrays, got {a.ndim}D for A") + if b.ndim != 2: + raise ValueError(f"matmul_fp8_fp8_blockwise_sm120 requires 2D arrays, got {b.ndim}D for B") + + if a.shape[1] != b.shape[0]: + raise ValueError( + f"matmul_fp8_fp8_blockwise_sm120 dimension mismatch: {a.shape} @ {b.shape} " + f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)" + ) + + if a.dtype != uint8 or b.dtype != uint8: + raise ValueError("matmul_fp8_fp8_blockwise_sm120 requires uint8 inputs (FP8)") + + if scale_a.dtype != float32 or scale_b.dtype != float32: + raise ValueError("matmul_fp8_fp8_blockwise_sm120 requires float32 scale factors") + + if not fp8_fp8_sm120_available(): + raise RuntimeError("FP8 blockwise SM120 GEMM is not available. Requires SM120+.") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _matmul_fp8_fp8_blockwise_sm120_native(a, b, scale_a, scale_b, out=out) + else: + raise RuntimeError("FP8 blockwise SM120 GEMM requires native backend") + + +def _matmul_fp8_fp8_blockwise_sm120_native( + a: GPUArray, + b: GPUArray, + scale_a: GPUArray, + scale_b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """Native C++ implementation of blockwise FP8 I/O GEMM for SM120.""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + + # Get native arrays + a_native = a._get_native() + b_native = b._get_native() + scale_a_native = scale_a._get_native() + scale_b_native = scale_b._get_native() + + # Allocate output if needed + if out is None: + M, K = a.shape + N = b.shape[1] + out_native = native.empty([M, N], native.DataType.UInt8) + out = GPUArray._wrap_native(out_native) + else: + out_native = out._get_native() + + # Call blockwise FP8 GEMM + native.gemm_fp8_fp8_blockwise_sm120( + a_native, b_native, out_native, scale_a_native, scale_b_native + ) + + return out + + +def matmul_fp8_sm100( + a: GPUArray, + b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """FP8 matrix multiplication for SM100 (Blackwell datacenter). + + This function takes FP32 inputs, internally quantizes them to FP8, + performs the GEMM using CUTLASS FP8 kernels with BF16 accumulation, + and returns the result as FP32. + + This may work on SM120 (Blackwell GeForce) as a fallback since both + are Blackwell architecture. + + Args: + a: First input array (M x K), FP32. + b: Second input array (K x N), FP32. 
+ out: Optional output array (M x N), FP32. If provided, result is + written to this array instead of allocating a new one. + + Returns: + The result GPUArray (M x N), FP32. + + Raises: + ValueError: If arrays are not 2D, not FP32, or dimensions don't match. + RuntimeError: If FP8 SM100 GEMM is not available or kernel fails. + + Example: + >>> import pygpukit as gk + >>> A = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1) + >>> B = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1) + >>> C = gk.ops.matmul_fp8_sm100(A, B) + """ + from pygpukit.core.dtypes import float32 + + if a.ndim != 2: + raise ValueError(f"matmul_fp8_sm100 requires 2D arrays, got {a.ndim}D for first argument") + if b.ndim != 2: + raise ValueError(f"matmul_fp8_sm100 requires 2D arrays, got {b.ndim}D for second argument") + + if a.shape[1] != b.shape[0]: + raise ValueError( + f"matmul_fp8_sm100 dimension mismatch: {a.shape} @ {b.shape} " + f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)" + ) + + if a.dtype != float32 or b.dtype != float32: + raise ValueError("matmul_fp8_sm100 requires float32 inputs") + + if not fp8_sm100_available(): + raise RuntimeError( + "FP8 SM100 GEMM is not available. Requires SM100+ GPU and CUTLASS SM100 support." + ) + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _matmul_fp8_sm100_native(a, b, out=out) + else: + raise RuntimeError("FP8 SM100 GEMM requires native backend") + + +def _matmul_fp8_sm100_native( + a: GPUArray, + b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """Native C++ implementation of FP8 GEMM for SM100.""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + + # Get native arrays + a_native = a._get_native() + b_native = b._get_native() + + # Allocate output if needed + if out is None: + M, K = a.shape + N = b.shape[1] + out_native = native.empty([M, N], native.DataType.Float32) + out = GPUArray._wrap_native(out_native) + else: + out_native = out._get_native() + + # Call FP8 GEMM + native.gemm_fp8_sm100(a_native, b_native, out_native) + + return out + + +def matmul_fp8_sm120( + a: GPUArray, + b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """FP8 matrix multiplication for SM120 (Blackwell GeForce). + + This function takes FP32 inputs, internally quantizes them to FP8, + performs the GEMM using CUTLASS FP8 kernels with BF16 accumulation, + and returns the result as FP32. + + Args: + a: First input array (M x K), FP32. + b: Second input array (K x N), FP32. + out: Optional output array (M x N), FP32. If provided, result is + written to this array instead of allocating a new one. + + Returns: + The result GPUArray (M x N), FP32. + + Raises: + ValueError: If arrays are not 2D, not FP32, or dimensions don't match. + RuntimeError: If FP8 SM120 GEMM is not available or kernel fails. 
+ + Example: + >>> import pygpukit as gk + >>> A = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1) + >>> B = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1) + >>> C = gk.ops.matmul_fp8_sm120(A, B) + """ + from pygpukit.core.dtypes import float32 + + if a.ndim != 2: + raise ValueError(f"matmul_fp8_sm120 requires 2D arrays, got {a.ndim}D for first argument") + if b.ndim != 2: + raise ValueError(f"matmul_fp8_sm120 requires 2D arrays, got {b.ndim}D for second argument") + + if a.shape[1] != b.shape[0]: + raise ValueError( + f"matmul_fp8_sm120 dimension mismatch: {a.shape} @ {b.shape} " + f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)" + ) + + if a.dtype != float32 or b.dtype != float32: + raise ValueError("matmul_fp8_sm120 requires float32 inputs") + + if not fp8_sm120_available(): + raise RuntimeError( + "FP8 SM120 GEMM is not available. Requires SM120+ GPU and CUTLASS SM120 support." + ) + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _matmul_fp8_sm120_native(a, b, out=out) + else: + raise RuntimeError("FP8 SM120 GEMM requires native backend") + + +def _matmul_fp8_sm120_native( + a: GPUArray, + b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """Native C++ implementation of FP8 GEMM for SM120.""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + + # Get native arrays + a_native = a._get_native() + b_native = b._get_native() + + # Allocate output if needed + if out is None: + M, K = a.shape + N = b.shape[1] + out_native = native.empty([M, N], native.DataType.Float32) + out = GPUArray._wrap_native(out_native) + else: + out_native = out._get_native() + + # Call FP8 GEMM + native.gemm_fp8_sm120(a_native, b_native, out_native) + + return out + + +def matmul_fp8_sm90( + a: GPUArray, + b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """FP8 matrix multiplication for SM90 (Hopper). + + This function takes FP32 inputs, internally quantizes them to FP8 with + per-tensor scaling, performs the GEMM using CUTLASS FP8 kernels, + and returns the result as FP32. + + Args: + a: First input array (M x K), FP32. + b: Second input array (K x N), FP32. + out: Optional output array (M x N), FP32. If provided, result is + written to this array instead of allocating a new one. + + Returns: + The result GPUArray (M x N), FP32. + + Raises: + ValueError: If arrays are not 2D, not FP32, or dimensions don't match. + RuntimeError: If FP8 SM90 GEMM is not available or kernel fails. + + Example: + >>> import pygpukit as gk + >>> A = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1) + >>> B = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1) + >>> C = gk.ops.matmul_fp8_sm90(A, B) + """ + from pygpukit.core.dtypes import float32 + + if a.ndim != 2: + raise ValueError(f"matmul_fp8_sm90 requires 2D arrays, got {a.ndim}D for first argument") + if b.ndim != 2: + raise ValueError(f"matmul_fp8_sm90 requires 2D arrays, got {b.ndim}D for second argument") + + if a.shape[1] != b.shape[0]: + raise ValueError( + f"matmul_fp8_sm90 dimension mismatch: {a.shape} @ {b.shape} " + f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)" + ) + + if a.dtype != float32 or b.dtype != float32: + raise ValueError("matmul_fp8_sm90 requires float32 inputs") + + if not fp8_sm90_available(): + raise RuntimeError( + "FP8 SM90 GEMM is not available. Requires SM90+ GPU and CUTLASS SM90 support." 
+ ) + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _matmul_fp8_sm90_native(a, b, out=out) + else: + raise RuntimeError("FP8 SM90 GEMM requires native backend") + + +def _matmul_fp8_sm90_native( + a: GPUArray, + b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """Native C++ implementation of FP8 GEMM for SM90.""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + + # Get native arrays + a_native = a._get_native() + b_native = b._get_native() + + # Allocate output if needed + if out is None: + M, K = a.shape + N = b.shape[1] + out_native = native.empty([M, N], native.DataType.Float32) + out = GPUArray._wrap_native(out_native) + else: + out_native = out._get_native() + + # Call FP8 GEMM + native.gemm_fp8_sm90(a_native, b_native, out_native) + + return out + + +def nvf4_bf16_sm120_available() -> bool: + """Check if NVF4 (4-bit) BF16 GEMM is available on SM120 (Blackwell GeForce). + + This variant uses NVF4 (4-bit float) for 2x memory bandwidth compared to FP8, + making it ideal for memory-bound LLM inference workloads. + + Returns: + True if NVF4 BF16 SM120 GEMM is available, False otherwise. + """ + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return native.nvf4_bf16_sm120_available() + else: + return False + + +def matmul_nvf4_bf16_sm120( + a: GPUArray, + b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """NVF4 (4-bit) GEMM with BF16 input/output for SM120 (Blackwell GeForce). + + This variant uses NVF4 (float_e2m1_t, 4-bit) for the internal computation, + providing 2x memory bandwidth compared to FP8. Ideal for memory-bound + LLM inference workloads. + + Data flow: BF16 input -> NVF4 quantize with block scaling -> GEMM -> BF16 output + + Args: + a: First input array (M x K), BF16. + b: Second input array (K x N), BF16. + out: Optional output array (M x N), BF16. + + Returns: + The result GPUArray (M x N), BF16. + + Raises: + ValueError: If arrays are not 2D, not BF16, or dimensions don't match. + RuntimeError: If NVF4 BF16 SM120 GEMM is not available. + """ + from pygpukit.core.dtypes import bfloat16 + + if a.ndim != 2: + raise ValueError(f"matmul_nvf4_bf16_sm120 requires 2D arrays, got {a.ndim}D") + if b.ndim != 2: + raise ValueError(f"matmul_nvf4_bf16_sm120 requires 2D arrays, got {b.ndim}D") + + if a.shape[1] != b.shape[0]: + raise ValueError(f"matmul_nvf4_bf16_sm120 dimension mismatch: {a.shape} @ {b.shape}") + + if a.dtype != bfloat16 or b.dtype != bfloat16: + raise ValueError("matmul_nvf4_bf16_sm120 requires bfloat16 inputs") + + if not nvf4_bf16_sm120_available(): + raise RuntimeError("NVF4 BF16 SM120 GEMM is not available. 
Requires SM120+ GPU.") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _matmul_nvf4_bf16_sm120_native(a, b, out=out) + else: + raise RuntimeError("NVF4 BF16 SM120 GEMM requires native backend") + + +def _matmul_nvf4_bf16_sm120_native( + a: GPUArray, + b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """Native C++ implementation of NVF4 BF16 GEMM for SM120.""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + + # Get native arrays + a_native = a._get_native() + b_native = b._get_native() + + # Allocate output if needed + if out is None: + M, K = a.shape + N = b.shape[1] + out_native = native.empty([M, N], native.DataType.BFloat16) + out = GPUArray._wrap_native(out_native) + else: + out_native = out._get_native() + + # Call NVF4 BF16 GEMM + native.gemm_nvf4_bf16_sm120(a_native, b_native, out_native) + + return out + + +# ============================================================================ +# GEMV Operations (M=1 special case) +# ============================================================================ + + +def gemv_nvf4_available() -> bool: + """Check if NVF4 GEMV is available (SM120+). + + Returns: + True if NVF4 GEMV is available on current GPU. + """ + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return native.gemv_nvf4_available() + else: + return False + + +def nvf4_get_sizes(K: int, N: int) -> tuple[int, int]: + """Get buffer sizes for NVF4-quantized weights. + + Args: + K: Inner dimension (input features). + N: Output dimension (output features). + + Returns: + Tuple of (data_size, scale_size) in bytes. + - data_size: Size for packed NVF4 weights [K/2, N] + - scale_size: Size for UE4M3 scale factors [K/32, N] + + Note: + NVF4 provides 4x compression vs BF16: + - BF16 weight size: K * N * 2 bytes + - NVF4 total size: K/2 * N + K/32 * N bytes + """ + data_size = (K // 2) * N + scale_size = ((K + 31) // 32) * N + return data_size, scale_size + + +def quantize_bf16_to_nvf4( + input: GPUArray, + out_data: GPUArray, + out_scale: GPUArray, +) -> None: + """Quantize BF16 weights to NVF4 format with block scaling. + + This quantizes BF16 weights to 4-bit NVF4 format with UE4M3 scale factors. + Each 32-element block shares one scale factor. + + Args: + input: BF16 weight matrix [K, N]. + out_data: Pre-allocated buffer for packed NVF4 data [K/2, N] (uint8). + out_scale: Pre-allocated buffer for scale factors [K/32, N] (uint8). + + Raises: + ValueError: If input is not 2D BF16, or buffers have wrong size. + RuntimeError: If NVF4 is not available. + + Note: + NVF4 values: {0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0} and negatives. + Block size: 32 elements per scale factor. + """ + from pygpukit.core.dtypes import bfloat16 + + if input.ndim != 2: + raise ValueError(f"quantize_bf16_to_nvf4 requires 2D input, got {input.ndim}D") + + if input.dtype != bfloat16: + raise ValueError(f"quantize_bf16_to_nvf4 requires bfloat16 input, got {input.dtype}") + + if not gemv_nvf4_available(): + raise RuntimeError("NVF4 quantization not available. 
Requires SM120+ GPU.") + + K, N = input.shape + expected_data_size, expected_scale_size = nvf4_get_sizes(K, N) + + # Validate buffer sizes (count elements) + actual_data_size = ( + out_data.shape[0] * out_data.shape[1] if out_data.ndim == 2 else out_data.size + ) + actual_scale_size = ( + out_scale.shape[0] * out_scale.shape[1] if out_scale.ndim == 2 else out_scale.size + ) + + if actual_data_size < expected_data_size: + raise ValueError(f"out_data buffer too small: {actual_data_size} < {expected_data_size}") + if actual_scale_size < expected_scale_size: + raise ValueError(f"out_scale buffer too small: {actual_scale_size} < {expected_scale_size}") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + input_native = input._get_native() + data_native = out_data._get_native() + scale_native = out_scale._get_native() + native.quantize_bf16_to_nvf4(input_native, data_native, scale_native) + + +def gemv_nvf4_bf16( + a: GPUArray, + b_data: GPUArray, + b_scale: GPUArray, + *, + out: GPUArray | None = None, + alpha: float = 1.0, +) -> GPUArray: + """NVF4 GEMV: C[N] = alpha * A[K] @ B[K,N] (NVF4 quantized). + + This performs matrix-vector multiplication where the weight matrix B + is pre-quantized to NVF4 format with block scaling. + + Args: + a: Input vector [K], BF16. + b_data: Packed NVF4 weight data [K/2, N], uint8. + b_scale: UE4M3 scale factors [K/32, N], uint8. + out: Optional output vector [N], BF16. + alpha: Scaling factor (default 1.0). + + Returns: + Output vector [N], BF16. + + Raises: + ValueError: If shapes or dtypes don't match. + RuntimeError: If NVF4 GEMV is not available. + + Note: + For LLM inference decode path (M=1), NVF4 provides 4x bandwidth + reduction vs BF16, which is critical for memory-bound workloads. + """ + from pygpukit.core.dtypes import bfloat16 + + if a.ndim != 1: + raise ValueError(f"gemv_nvf4_bf16 requires 1D input vector, got {a.ndim}D") + + if a.dtype != bfloat16: + raise ValueError(f"gemv_nvf4_bf16 requires bfloat16 input, got {a.dtype}") + + if not gemv_nvf4_available(): + raise RuntimeError("NVF4 GEMV not available. Requires SM120+ GPU.") + + # Infer N from b_data shape: [K/2, N] + if b_data.ndim == 2: + N = b_data.shape[1] + else: + raise ValueError(f"b_data must be 2D [K/2, N], got {b_data.ndim}D") + + # Validate output + if out is not None: + if out.shape != (N,): + raise ValueError(f"out shape {out.shape} does not match expected ({N},)") + if out.dtype != bfloat16: + raise ValueError(f"out dtype {out.dtype} must be bfloat16") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + + a_native = a._get_native() + data_native = b_data._get_native() + scale_native = b_scale._get_native() + + if out is None: + out_native = native.empty([N], native.DataType.BFloat16) + out = GPUArray._wrap_native(out_native) + else: + out_native = out._get_native() + + native.gemv_nvf4_bf16(a_native, data_native, scale_native, out_native, alpha) + + return out + else: + raise RuntimeError("NVF4 GEMV requires native backend") + + +def gemv_bf16( + a: GPUArray, + b: GPUArray, + *, + out: GPUArray | None = None, + alpha: float = 1.0, + beta: float = 0.0, +) -> GPUArray: + """BF16 GEMV: C[N] = alpha * A[K] @ B[K,N] + beta * C[N]. + + Standard BF16 matrix-vector multiplication without quantization. 
+ + Args: + a: Input vector [K], BF16. + b: Weight matrix [K, N], BF16 (row-major). + out: Optional output vector [N], BF16. + alpha: Scaling factor for A @ B (default 1.0). + beta: Scaling factor for existing C (default 0.0). + + Returns: + Output vector [N], BF16. + + Raises: + ValueError: If shapes or dtypes don't match. + """ + from pygpukit.core.dtypes import bfloat16 + + if a.ndim != 1: + raise ValueError(f"gemv_bf16 requires 1D input vector, got {a.ndim}D") + + if b.ndim != 2: + raise ValueError(f"gemv_bf16 requires 2D weight matrix, got {b.ndim}D") + + if a.dtype != bfloat16 or b.dtype != bfloat16: + raise ValueError("gemv_bf16 requires bfloat16 inputs") + + K = a.shape[0] + if b.shape[0] != K: + raise ValueError(f"gemv_bf16 dimension mismatch: A[{K}] vs B[{b.shape[0]}, {b.shape[1]}]") + + N = b.shape[1] + + # Validate output + if out is not None: + if out.shape != (N,): + raise ValueError(f"out shape {out.shape} does not match expected ({N},)") + if out.dtype != bfloat16: + raise ValueError(f"out dtype {out.dtype} must be bfloat16") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + + a_native = a._get_native() + b_native = b._get_native() + + if out is None: + out_native = native.empty([N], native.DataType.BFloat16) + out = GPUArray._wrap_native(out_native) + else: + out_native = out._get_native() + + native.gemv_bf16(a_native, b_native, out_native, alpha, beta) + + return out + else: + # CPU fallback + a_np: np.ndarray[np.floating] = a.to_numpy().astype(np.float32) + b_np: np.ndarray[np.floating] = b.to_numpy().astype(np.float32) + result: np.ndarray[np.floating] = alpha * (a_np @ b_np) + if out is not None: + result = result + beta * out.to_numpy().astype(np.float32) + return from_numpy(result.astype(np.float16).view(np.uint16).astype(np.uint16)) + + +# ============================================================================ +# FP8 Operations +# ============================================================================ + + +def matmul_fp8( + a: GPUArray, + b: GPUArray, + *, + out: GPUArray | None = None, +) -> GPUArray: + """FP8 matrix multiplication with automatic backend selection. + + This function takes FP32 inputs, internally quantizes them to FP8, + performs the GEMM using the best available CUTLASS FP8 kernel, + and returns the result as FP32. + + Backend priority: + - SM120 (Blackwell GeForce): blockwise scaling (when CUTLASS bug #2902 is fixed) + - SM90 (Hopper): per-tensor scaling + + Args: + a: First input array (M x K), FP32. + b: Second input array (K x N), FP32. + out: Optional output array (M x N), FP32. If provided, result is + written to this array instead of allocating a new one. + + Returns: + The result GPUArray (M x N), FP32. + + Raises: + ValueError: If arrays are not 2D, not FP32, or dimensions don't match. + RuntimeError: If no FP8 GEMM backend is available. 
+ + Example: + >>> import pygpukit as gk + >>> A = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1) + >>> B = gk.from_numpy(np.random.randn(1024, 1024).astype(np.float32) * 0.1) + >>> C = gk.ops.matmul_fp8(A, B) + """ + from pygpukit.core.dtypes import float32 + + if a.ndim != 2: + raise ValueError(f"matmul_fp8 requires 2D arrays, got {a.ndim}D for first argument") + if b.ndim != 2: + raise ValueError(f"matmul_fp8 requires 2D arrays, got {b.ndim}D for second argument") + + if a.shape[1] != b.shape[0]: + raise ValueError( + f"matmul_fp8 dimension mismatch: {a.shape} @ {b.shape} " + f"(inner dimensions {a.shape[1]} and {b.shape[0]} must match)" + ) + + if a.dtype != float32 or b.dtype != float32: + raise ValueError("matmul_fp8 requires float32 inputs") + + if not fp8_available(): + raise RuntimeError("FP8 GEMM is not available. Requires SM90+ GPU and CUTLASS support.") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + + # Get native arrays + a_native = a._get_native() + b_native = b._get_native() + + # Allocate output if needed + if out is None: + M, K = a.shape + N = b.shape[1] + out_native = native.empty([M, N], native.DataType.Float32) + out = GPUArray._wrap_native(out_native) + else: + out_native = out._get_native() + + # Call auto-dispatch FP8 GEMM + native.gemm_fp8(a_native, b_native, out_native) + + return out + else: + raise RuntimeError("FP8 GEMM requires native backend") diff --git a/src/pygpukit/ops/nn.py b/src/pygpukit/ops/nn.py index 3d29861..1637abf 100644 --- a/src/pygpukit/ops/nn.py +++ b/src/pygpukit/ops/nn.py @@ -112,6 +112,67 @@ def _silu_native(a: GPUArray, *, out: GPUArray | None = None) -> GPUArray: return GPUArray._wrap_native(c_native) +def sigmoid(a: GPUArray, *, out: GPUArray | None = None) -> GPUArray: + """Sigmoid activation: y = 1 / (1 + exp(-x)). + + Args: + a: Input array. + out: Optional pre-allocated output array. + + Returns: + A new GPUArray containing the sigmoid-activated values. + """ + _validate_float_dtype(a, "sigmoid") + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + a_native = a._get_native() + + if out is not None: + out_native = out._get_native() + native.sigmoid_(a_native, out_native) + return out + else: + return GPUArray._wrap_native(native.sigmoid(a_native)) + else: + x = a.to_numpy() + result = 1.0 / (1.0 + np.exp(-x)) + return from_numpy(result) + + +def tanh(a: GPUArray, *, out: GPUArray | None = None) -> GPUArray: + """Tanh activation. + + Args: + a: Input array. + out: Optional pre-allocated output array. + + Returns: + A new GPUArray containing the tanh-activated values. 
+ """ + _validate_float_dtype(a, "tanh") + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + a_native = a._get_native() + + if out is not None: + out_native = out._get_native() + native.tanh_(a_native, out_native) + return out + else: + return GPUArray._wrap_native(native.tanh(a_native)) + else: + x = a.to_numpy() + return from_numpy(np.tanh(x)) + + # ============================================================================= # Normalization Layers # ============================================================================= @@ -128,7 +189,7 @@ def layernorm( Computes: (x - mean) / sqrt(var + eps) * gamma + beta Args: - input: Input array of shape [batch, features]. + input: Input array of shape [batch, features] or [batch, seq_len, features]. gamma: Scale parameter of shape [features]. beta: Bias parameter of shape [features]. eps: Small epsilon for numerical stability. @@ -141,19 +202,36 @@ def layernorm( """ _validate_float_dtype(input, "layernorm") - if input.ndim != 2: - raise ValueError(f"layernorm expects 2D input [batch, features], got {input.ndim}D") + if input.ndim not in (2, 3): + raise ValueError(f"layernorm expects 2D or 3D input, got {input.ndim}D") if gamma.ndim != 1 or beta.ndim != 1: raise ValueError("layernorm expects 1D gamma and beta") if input.dtype != gamma.dtype or input.dtype != beta.dtype: raise ValueError("layernorm: all inputs must have same dtype") - features = input.shape[1] + features = input.shape[-1] # Last dimension is features if gamma.shape[0] != features or beta.shape[0] != features: raise ValueError( f"layernorm: gamma/beta size {gamma.shape[0]} must match features {features}" ) + # Handle 3D input by reshaping to 2D, processing, and reshaping back + if input.ndim == 3: + batch, seq_len, feat = input.shape + input_2d = input.reshape(batch * seq_len, feat) + result_2d = _layernorm_dispatch(input_2d, gamma, beta, eps) + return result_2d.reshape(batch, seq_len, feat) + else: + return _layernorm_dispatch(input, gamma, beta, eps) + + +def _layernorm_dispatch( + input: GPUArray, + gamma: GPUArray, + beta: GPUArray, + eps: float, +) -> GPUArray: + """Dispatch layernorm to native or CPU implementation.""" backend = get_backend() if isinstance(backend, NativeBackend) and backend.is_available(): diff --git a/src/pygpukit/ops/reduction.py b/src/pygpukit/ops/reduction.py index aa3df5f..6e786b5 100644 --- a/src/pygpukit/ops/reduction.py +++ b/src/pygpukit/ops/reduction.py @@ -130,35 +130,45 @@ def _max_native(a: GPUArray) -> GPUArray: return GPUArray._wrap_native(c_native) -def softmax(input: GPUArray) -> GPUArray: - """Softmax activation applied row-wise. +def softmax(input: GPUArray, axis: int = -1) -> GPUArray: + """Softmax activation along the specified axis. Computes: y[i] = exp(x[i] - max(x)) / sum(exp(x - max(x))) Args: - input: Input array of shape [batch, features]. + input: Input array of shape [..., features]. + Supports 2D, 3D, and 4D tensors. + axis: The axis along which to compute softmax (default: -1, last axis). Returns: - A new GPUArray containing the softmax output. + A new GPUArray containing the softmax output, same shape as input. Raises: - ValueError: If input is not 2D or dtype is not a float type. + ValueError: If dtype is not a float type or axis is invalid. 
""" _validate_float_dtype(input, "softmax") - if input.ndim != 2: - raise ValueError(f"softmax expects 2D input [batch, features], got {input.ndim}D") + if input.ndim < 2: + raise ValueError(f"softmax expects at least 2D input, got {input.ndim}D") + if input.ndim > 4: + raise ValueError(f"softmax supports up to 4D input, got {input.ndim}D") + + # Normalize axis + if axis < 0: + axis = input.ndim + axis + if axis != input.ndim - 1: + raise ValueError(f"softmax currently only supports axis=-1 (last axis), got axis={axis}") backend = get_backend() if isinstance(backend, NativeBackend) and backend.is_available(): - return _softmax_native(input) + return _softmax_native_nd(input) else: - return _softmax_cpu(input) + return _softmax_cpu_nd(input) def _softmax_cpu(input: GPUArray) -> GPUArray: - """CPU implementation of softmax.""" + """CPU implementation of softmax for 2D tensors.""" x = input.to_numpy() # Numerical stability: subtract max x_max = x.max(axis=1, keepdims=True) @@ -166,11 +176,126 @@ def _softmax_cpu(input: GPUArray) -> GPUArray: return from_numpy(exp_x / exp_x.sum(axis=1, keepdims=True)) +def _softmax_cpu_nd(input: GPUArray) -> GPUArray: + """CPU implementation of softmax for N-D tensors (axis=-1).""" + x = input.to_numpy() + # Numerical stability: subtract max along last axis + x_max = x.max(axis=-1, keepdims=True) + exp_x = np.exp(x - x_max) + return from_numpy(exp_x / exp_x.sum(axis=-1, keepdims=True)) + + def _softmax_native(input: GPUArray) -> GPUArray: - """Native C++ CUDA implementation of softmax (zero-copy).""" + """Native C++ CUDA implementation of softmax (zero-copy) for 2D tensors.""" from pygpukit.core.backend import get_native_module native = get_native_module() input_native = input._get_native() c_native = native.softmax(input_native) return GPUArray._wrap_native(c_native) + + +def _softmax_native_nd(input: GPUArray) -> GPUArray: + """Native C++ CUDA implementation of softmax for N-D tensors. + + Flattens leading dimensions into a single batch dimension, + applies softmax along the last axis, then reshapes back. + """ + from pygpukit.core.backend import get_native_module + + native = get_native_module() + original_shape = input.shape + + # Flatten all but last dimension into batch + features = original_shape[-1] + batch_size = 1 + for dim in original_shape[:-1]: + batch_size *= dim + + # Reshape to 2D [batch, features] + input_2d = input.reshape((batch_size, features)) + input_native = input_2d._get_native() + + # Apply softmax + c_native = native.softmax(input_native) + result_2d = GPUArray._wrap_native(c_native) + + # Reshape back to original shape + return result_2d.reshape(original_shape) + + +def min(a: GPUArray) -> GPUArray: + """Min of all elements. + + Args: + a: Input array (float types). + + Returns: + A scalar GPUArray (shape [1]) containing the minimum value. + """ + _validate_float_dtype(a, "min") + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return GPUArray._wrap_native(native.min(a._get_native())) + else: + a_np = a.to_numpy() + return from_numpy(np.array([np.min(a_np)], dtype=a_np.dtype)) + + +def argmax(a: GPUArray) -> GPUArray: + """Index of maximum element. + + Args: + a: Input array (float types). + + Returns: + A scalar GPUArray (shape [1], dtype int64) containing the index of the maximum value. 
+ """ + _validate_float_dtype(a, "argmax") + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return GPUArray._wrap_native(native.argmax(a._get_native())) + else: + a_np = a.to_numpy() + return from_numpy(np.array([np.argmax(a_np)], dtype=np.int64)) + + +def sum_axis(a: GPUArray, axis: int) -> GPUArray: + """Sum along specified axis for 2D tensors. + + Args: + a: Input 2D array [M, N] (float types). + axis: Axis to sum along (0 or 1). + axis=0: sum rows -> output [N] + axis=1: sum columns -> output [M] + + Returns: + A GPUArray with the sum along the specified axis. + + Raises: + ValueError: If input is not 2D or axis is not 0 or 1. + """ + _validate_float_dtype(a, "sum_axis") + if a.ndim != 2: + raise ValueError(f"sum_axis requires 2D input, got {a.ndim}D") + if axis not in (0, 1): + raise ValueError(f"sum_axis: axis must be 0 or 1, got {axis}") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return GPUArray._wrap_native(native.sum_axis(a._get_native(), axis)) + else: + a_np = a.to_numpy() + return from_numpy(np.sum(a_np, axis=axis)) diff --git a/src/pygpukit/ops/tensor.py b/src/pygpukit/ops/tensor.py index cbf1784..0583615 100644 --- a/src/pygpukit/ops/tensor.py +++ b/src/pygpukit/ops/tensor.py @@ -188,6 +188,199 @@ def _transpose_3d_021_native(input: GPUArray, *, out: GPUArray | None = None) -> return GPUArray._wrap_native(c_native) +def transpose_4d_0213(input: GPUArray, *, out: GPUArray | None = None) -> GPUArray | None: + """Transpose 4D tensor: [d0, d1, d2, d3] -> [d0, d2, d1, d3]. + + Swaps axes 1 and 2 while keeping axes 0 and 3 in place. + Common in attention operations to convert: + - [batch, seq, heads, dim] -> [batch, heads, seq, dim] + + Args: + input: 4D tensor to transpose. + out: Optional pre-allocated output buffer for CUDA Graph capture. + If provided, must have shape [d0, d2, d1, d3] and same dtype as input. + + Returns: + Transposed tensor with axes 1 and 2 swapped. + Returns None if out is provided (in-place operation). 
+ """ + _validate_float_dtype(input, "transpose_4d_0213") + + if input.ndim != 4: + raise ValueError(f"transpose_4d_0213 expects 4D input, got {input.ndim}D") + + backend = get_backend() + + # Native transpose_4d_0213 supports float32/float16/bfloat16 + if isinstance(backend, NativeBackend) and backend.is_available(): + dtype_str = str(input.dtype) + if dtype_str in ("float32", "float16", "bfloat16"): + return _transpose_4d_0213_native(input, out=out) + else: + if out is not None: + raise NotImplementedError( + "transpose_4d_0213: out parameter not supported for CPU fallback" + ) + return _transpose_4d_0213_cpu(input) + else: + if out is not None: + raise NotImplementedError( + "transpose_4d_0213: out parameter not supported for CPU fallback" + ) + return _transpose_4d_0213_cpu(input) + + +def _transpose_4d_0213_cpu(input: GPUArray) -> GPUArray: + """CPU fallback for transpose_4d_0213.""" + x = input.to_numpy() + result = np.transpose(x, (0, 2, 1, 3)).copy() + return from_numpy(result) + + +def _transpose_4d_0213_native(input: GPUArray, *, out: GPUArray | None = None) -> GPUArray | None: + """Native C++ CUDA implementation of transpose_4d_0213.""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + input_native = input._get_native() + + if out is not None: + out_native = out._get_native() + native.transpose_4d_0213_(input_native, out_native) + return None + else: + c_native = native.transpose_4d_0213(input_native) + return GPUArray._wrap_native(c_native) + + +def transpose_3d_012(input: GPUArray, *, out: GPUArray | None = None) -> GPUArray | None: + """Transpose 3D tensor: [d0, d1, d2] -> [d0, d2, d1]. + + Swaps last two axes while keeping axis 0 in place. + Useful for attention operations where K needs to be transposed. + + Args: + input: 3D tensor to transpose. + out: Optional pre-allocated output buffer for CUDA Graph capture. + If provided, must have shape [d0, d2, d1] and same dtype as input. + + Returns: + Transposed tensor with last two axes swapped. + Returns None if out is provided (in-place operation). 
+ """ + _validate_float_dtype(input, "transpose_3d_012") + + if input.ndim != 3: + raise ValueError(f"transpose_3d_012 expects 3D input, got {input.ndim}D") + + backend = get_backend() + + # Native transpose_3d_012 supports float32/float16/bfloat16 + if isinstance(backend, NativeBackend) and backend.is_available(): + dtype_str = str(input.dtype) + if dtype_str in ("float32", "float16", "bfloat16"): + return _transpose_3d_012_native(input, out=out) + else: + if out is not None: + raise NotImplementedError( + "transpose_3d_012: out parameter not supported for CPU fallback" + ) + return _transpose_3d_012_cpu(input) + else: + if out is not None: + raise NotImplementedError( + "transpose_3d_012: out parameter not supported for CPU fallback" + ) + return _transpose_3d_012_cpu(input) + + +def _transpose_3d_012_cpu(input: GPUArray) -> GPUArray: + """CPU implementation of transpose_3d_012.""" + x = input.to_numpy() + result = np.transpose(x, (0, 2, 1)).copy() + return from_numpy(result) + + +def _transpose_3d_012_native(input: GPUArray, *, out: GPUArray | None = None) -> GPUArray | None: + """Native C++ CUDA implementation of transpose_3d_012.""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + input_native = input._get_native() + + if out is not None: + out_native = out._get_native() + native.transpose_3d_012_(input_native, out_native) + return None + else: + c_native = native.transpose_3d_012(input_native) + return GPUArray._wrap_native(c_native) + + +def transpose_4d_0132(input: GPUArray, *, out: GPUArray | None = None) -> GPUArray | None: + """Transpose 4D tensor: [d0, d1, d2, d3] -> [d0, d1, d3, d2]. + + Swaps last two axes while keeping axes 0 and 1 in place. + Useful for K^T in attention operations. + + Args: + input: 4D tensor to transpose. + out: Optional pre-allocated output buffer for CUDA Graph capture. + If provided, must have shape [d0, d1, d3, d2] and same dtype as input. + + Returns: + Transposed tensor with last two axes swapped. + Returns None if out is provided (in-place operation). 
+ """ + _validate_float_dtype(input, "transpose_4d_0132") + + if input.ndim != 4: + raise ValueError(f"transpose_4d_0132 expects 4D input, got {input.ndim}D") + + backend = get_backend() + + # Native transpose_4d_0132 supports float32/float16/bfloat16 + if isinstance(backend, NativeBackend) and backend.is_available(): + dtype_str = str(input.dtype) + if dtype_str in ("float32", "float16", "bfloat16"): + return _transpose_4d_0132_native(input, out=out) + else: + if out is not None: + raise NotImplementedError( + "transpose_4d_0132: out parameter not supported for CPU fallback" + ) + return _transpose_4d_0132_cpu(input) + else: + if out is not None: + raise NotImplementedError( + "transpose_4d_0132: out parameter not supported for CPU fallback" + ) + return _transpose_4d_0132_cpu(input) + + +def _transpose_4d_0132_cpu(input: GPUArray) -> GPUArray: + """CPU fallback for transpose_4d_0132.""" + x = input.to_numpy() + result = np.transpose(x, (0, 1, 3, 2)).copy() + return from_numpy(result) + + +def _transpose_4d_0132_native(input: GPUArray, *, out: GPUArray | None = None) -> GPUArray | None: + """Native C++ CUDA implementation of transpose_4d_0132.""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + input_native = input._get_native() + + if out is not None: + out_native = out._get_native() + native.transpose_4d_0132_(input_native, out_native) + return None + else: + c_native = native.transpose_4d_0132(input_native) + return GPUArray._wrap_native(c_native) + + # ============================================================================= # Reshape Operations # ============================================================================= diff --git a/src/pygpukit/ops/unary.py b/src/pygpukit/ops/unary.py index 0ddfbc6..616f99f 100644 --- a/src/pygpukit/ops/unary.py +++ b/src/pygpukit/ops/unary.py @@ -130,3 +130,129 @@ def _relu_native(a: GPUArray) -> GPUArray: a_native = a._get_native() c_native = native.relu(a_native) return GPUArray._wrap_native(c_native) + + +def sin(a: GPUArray) -> GPUArray: + """Element-wise sine. + + Args: + a: Input array (float types). + + Returns: + A new GPUArray containing sin(a). + """ + _validate_float_dtype(a, "sin") + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return GPUArray._wrap_native(native.sin(a._get_native())) + else: + return from_numpy(np.sin(a.to_numpy())) + + +def cos(a: GPUArray) -> GPUArray: + """Element-wise cosine. + + Args: + a: Input array (float types). + + Returns: + A new GPUArray containing cos(a). + """ + _validate_float_dtype(a, "cos") + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return GPUArray._wrap_native(native.cos(a._get_native())) + else: + return from_numpy(np.cos(a.to_numpy())) + + +def sqrt(a: GPUArray) -> GPUArray: + """Element-wise square root. + + Args: + a: Input array (float types). + + Returns: + A new GPUArray containing sqrt(a). 
+ """ + _validate_float_dtype(a, "sqrt") + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return GPUArray._wrap_native(native.sqrt(a._get_native())) + else: + return from_numpy(np.sqrt(a.to_numpy())) + + +def rsqrt(a: GPUArray) -> GPUArray: + """Element-wise reciprocal square root: 1/sqrt(x). + + Args: + a: Input array (float types). + + Returns: + A new GPUArray containing 1/sqrt(a). + """ + _validate_float_dtype(a, "rsqrt") + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return GPUArray._wrap_native(native.rsqrt(a._get_native())) + else: + return from_numpy(1.0 / np.sqrt(a.to_numpy())) + + +def abs(a: GPUArray) -> GPUArray: + """Element-wise absolute value. + + Args: + a: Input array (float types). + + Returns: + A new GPUArray containing |a|. + """ + _validate_float_dtype(a, "abs") + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return GPUArray._wrap_native(native.abs(a._get_native())) + else: + return from_numpy(np.abs(a.to_numpy())) + + +def neg(a: GPUArray) -> GPUArray: + """Element-wise negation: -x. + + Args: + a: Input array (float types). + + Returns: + A new GPUArray containing -a. + """ + _validate_float_dtype(a, "neg") + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + return GPUArray._wrap_native(native.neg(a._get_native())) + else: + return from_numpy(-a.to_numpy()) diff --git a/tests/test_fp8_sm120.py b/tests/test_fp8_sm120.py new file mode 100644 index 0000000..fd72f34 --- /dev/null +++ b/tests/test_fp8_sm120.py @@ -0,0 +1,35 @@ +"""Test FP8 GEMM with compute-sanitizer.""" + +import numpy as np + +from pygpukit.core.factory import from_numpy +from pygpukit.ops import fp8_sm120_available, matmul_fp8_sm120 + +print(f"FP8 SM120 available: {fp8_sm120_available()}") + +if fp8_sm120_available(): + # Use exact tile size (single tile) to eliminate edge cases + M, N, K = 128, 128, 128 + print(f"Testing with exact tile size: M={M}, N={N}, K={K}") + + A = np.random.randn(M, K).astype(np.float32) * 0.1 # Small values for FP8 + B = np.random.randn(K, N).astype(np.float32) * 0.1 + + A_gpu = from_numpy(A) + B_gpu = from_numpy(B) + + print("Running FP8 GEMM...") + try: + C_gpu = matmul_fp8_sm120(A_gpu, B_gpu) + print("FP8 GEMM succeeded!") + C = C_gpu.to_numpy() + print(f"Output shape: {C.shape}, dtype: {C.dtype}") + + # Verify against numpy + C_ref = A @ B + rel_error = np.linalg.norm(C - C_ref) / np.linalg.norm(C_ref) + print(f"Relative error vs NumPy: {rel_error:.6e}") + except Exception as e: + print(f"FP8 GEMM failed: {e}") +else: + print("FP8 SM120 not available") diff --git a/tests/test_nvf4_bf16_sm120.py b/tests/test_nvf4_bf16_sm120.py new file mode 100644 index 0000000..0f323a7 --- /dev/null +++ b/tests/test_nvf4_bf16_sm120.py @@ -0,0 +1,136 @@ +"""Test NVF4-BF16 GEMM for SM120 (Blackwell GeForce).""" + +import numpy as np + +from pygpukit.core.factory import from_numpy +from pygpukit.ops import matmul_nvf4_bf16_sm120, nvf4_bf16_sm120_available + + +def bf16_to_f32(bf16_uint16: np.ndarray) -> np.ndarray: + """Convert BFloat16 (stored as uint16) 
to float32. + + BFloat16 is the top 16 bits of float32, so we just left-shift by 16. + """ + # Ensure input is uint16 + bf16_uint16 = bf16_uint16.astype(np.uint16) + + # Shift to get float32 bits + f32_bits = bf16_uint16.astype(np.uint32) << 16 + + # View as float32 + return f32_bits.view(np.float32) + + +def f32_to_bf16(f32: np.ndarray) -> np.ndarray: + """Convert float32 to BFloat16 (stored as uint16). + + Just take the top 16 bits of the float32 representation. + """ + f32 = f32.astype(np.float32) + f32_bits = f32.view(np.uint32) + bf16_bits = (f32_bits >> 16).astype(np.uint16) + return bf16_bits + + +def test_nvf4_bf16_gemm(): + """Test NVF4-BF16 GEMM correctness.""" + print(f"NVF4-BF16 SM120 available: {nvf4_bf16_sm120_available()}") + + if not nvf4_bf16_sm120_available(): + print("NVF4-BF16 SM120 not available, skipping test") + return + + # Test with simple values first: all 2.0 + # Expected result: 2.0 * 2.0 * K = 512 for K=128 + M, N, K = 128, 128, 128 + print(f"Testing with dimensions: M={M}, N={N}, K={K}") + + # Create input data in float32, then convert to BF16 (uint16) + A_f32 = np.full((M, K), 2.0, dtype=np.float32) + B_f32 = np.full((K, N), 2.0, dtype=np.float32) + + # Convert to BFloat16 representation (uint16) + A_bf16 = f32_to_bf16(A_f32) + B_bf16 = f32_to_bf16(B_f32) + + print(f"A[0,0] as uint16: {A_bf16[0, 0]} (0x{A_bf16[0, 0]:04X})") + print(f"B[0,0] as uint16: {B_bf16[0, 0]} (0x{B_bf16[0, 0]:04X})") + + # Upload to GPU + A_gpu = from_numpy(A_bf16) + B_gpu = from_numpy(B_bf16) + + print(f"A_gpu dtype: {A_gpu.dtype}") + print(f"B_gpu dtype: {B_gpu.dtype}") + + print("Running NVF4-BF16 GEMM...") + try: + C_gpu = matmul_nvf4_bf16_sm120(A_gpu, B_gpu) + print("NVF4-BF16 GEMM succeeded!") + + # Get result as uint16 (raw BFloat16 storage) + C_uint16 = C_gpu.to_numpy() + print(f"C[0,0] as uint16: {C_uint16[0, 0]} (0x{C_uint16[0, 0]:04X})") + + # Convert to float32 for verification + C_f32 = bf16_to_f32(C_uint16) + print(f"C[0,0] as float32: {C_f32[0, 0]}") + print(f"Output shape: {C_f32.shape}, dtype: {C_f32.dtype}") + + # Expected: 2.0 * 2.0 * 128 = 512.0 + expected = 512.0 + actual = C_f32[0, 0] + print(f"Expected: {expected}, Actual: {actual}") + + if abs(actual - expected) < 1.0: # Allow small tolerance for quantization + print("PASS: NVF4-BF16 GEMM produces correct result!") + else: + print(f"FAIL: Expected {expected}, got {actual}") + + # Test with NVF4-appropriate random values + # NVF4 values: {0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0} and negatives + print("\n--- Testing with NVF4-appropriate random values ---") + nvf4_values = np.array( + [0.5, 1.0, 1.5, 2.0, 3.0, 4.0] + ) # Positive values only for simpler test + A_rand = np.random.choice(nvf4_values, size=(M, K)).astype(np.float32) + B_rand = np.random.choice(nvf4_values, size=(K, N)).astype(np.float32) + + A_rand_bf16 = f32_to_bf16(A_rand) + B_rand_bf16 = f32_to_bf16(B_rand) + + A_rand_gpu = from_numpy(A_rand_bf16) + B_rand_gpu = from_numpy(B_rand_bf16) + + C_rand_gpu = matmul_nvf4_bf16_sm120(A_rand_gpu, B_rand_gpu) + C_rand_uint16 = C_rand_gpu.to_numpy() + C_rand_f32 = bf16_to_f32(C_rand_uint16) + + # Reference: use BF16 precision for comparison + A_rand_ref = bf16_to_f32(A_rand_bf16) + B_rand_ref = bf16_to_f32(B_rand_bf16) + C_ref = A_rand_ref @ B_rand_ref + + # Compare + abs_error = np.abs(C_rand_f32 - C_ref).mean() + ref_scale = np.abs(C_ref).mean() + rel_error = abs_error / ref_scale if ref_scale > 0 else abs_error + print(f"Mean absolute error: {abs_error:.6e}") + print(f"Reference mean absolute: 
{ref_scale:.6e}") + print(f"Relative error: {rel_error:.2%}") + + # With exact NVF4 values as input, quantization should be exact + if rel_error < 0.05: # Allow 5% for BF16 accumulation errors + print("PASS: NVF4-BF16 GEMM with random values!") + else: + print(f"FAIL: Large relative error {rel_error:.2%}") + + except Exception as e: + print(f"NVF4-BF16 GEMM failed: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + test_nvf4_bf16_gemm() diff --git a/third_party/cutlass b/third_party/cutlass index d55f6be..65e7e40 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit d55f6beeebb6df501a250dc82827db97660f06e0 +Subproject commit 65e7e401e2d4a6153f0bd66d761345c988198b2d