From 7da4d10cebd70decbe6c733e119e1757369b7832 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Sun, 14 Dec 2025 23:23:14 +0900 Subject: [PATCH 01/24] feat(jit): add JIT compiler stabilization (#55) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## NVRTC Error Handling - Add NvrtcErrorCode enum (C++ and Python) - NvrtcError now includes error code and compilation log - Expose compilation_log property in Python exceptions ## PTX ISA Version Detection & Fallback - Add get_recommended_arch() for automatic architecture selection - Add get_fallback_archs() for fallback architecture list - Auto-retry PTX loading with lower architectures on ISA mismatch ## Retry Logic for Transient Failures - Retry OutOfMemory, InternalError up to 3 times - Exponential backoff (100ms, 200ms, 400ms) ## JIT Warmup System - Add warmup() function for pre-initializing NVRTC - Support background warmup with callback - Add is_warmup_done(), get_warmup_error() for status ## Driver Version Documentation - Add get_driver_requirements() returning min requirements - Add check_driver_compatibility() for compatibility check - Minimum: CUDA 11.0+, SM 8.0 (Ampere)+ 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- native/bindings/core_bindings.cpp | 6 + native/bindings/jit_bindings.cpp | 53 ++++ native/core/device.cpp | 95 ++++++ native/core/device.hpp | 13 + native/core/types.hpp | 67 +++- native/jit/compiler.cpp | 19 +- native/jit/kernel.cpp | 20 +- src/pygpukit/__init__.py | 14 + src/pygpukit/jit/__init__.py | 24 +- src/pygpukit/jit/compiler.py | 495 +++++++++++++++++++++++++++++- 10 files changed, 790 insertions(+), 16 deletions(-) diff --git a/native/bindings/core_bindings.cpp b/native/bindings/core_bindings.cpp index 9c60904..e095bb9 100644 --- a/native/bindings/core_bindings.cpp +++ b/native/bindings/core_bindings.cpp @@ -49,6 +49,12 @@ void init_core_bindings(py::module_& m) { m.def("validate_compute_capability", &validate_compute_capability, py::arg("device_id") = 0, "Validate device compute capability (requires SM >= 80)"); + m.def("get_recommended_arch", &get_recommended_arch, py::arg("device_id") = 0, + "Get recommended -arch option for JIT compilation (e.g., 'sm_86')"); + m.def("get_fallback_archs", &get_fallback_archs, py::arg("device_id") = 0, + "Get fallback -arch options for older drivers (in order of preference)"); + m.def("is_arch_supported", &is_arch_supported, py::arg("arch"), + "Check if driver supports a given PTX architecture"); // GPUArray class py::class_(m, "GPUArray") diff --git a/native/bindings/jit_bindings.cpp b/native/bindings/jit_bindings.cpp index 98c1309..39b0db5 100644 --- a/native/bindings/jit_bindings.cpp +++ b/native/bindings/jit_bindings.cpp @@ -7,7 +7,60 @@ namespace py = pybind11; using namespace pygpukit; +// Custom exception for NVRTC errors with structured info +static PyObject* NvrtcErrorType = nullptr; + void init_jit_bindings(py::module_& m) { + // NvrtcErrorCode enum + py::enum_(m, "NvrtcErrorCode") + .value("Success", NvrtcErrorCode::Success) + .value("OutOfMemory", NvrtcErrorCode::OutOfMemory) + .value("ProgramCreationFailure", NvrtcErrorCode::ProgramCreationFailure) + .value("InvalidInput", NvrtcErrorCode::InvalidInput) + .value("InvalidProgram", NvrtcErrorCode::InvalidProgram) + .value("InvalidOption", NvrtcErrorCode::InvalidOption) + .value("Compilation", NvrtcErrorCode::Compilation) + .value("BuiltinOperationFailure", NvrtcErrorCode::BuiltinOperationFailure) + .value("NoNameExpressionsAfterCompilation", NvrtcErrorCode::NoNameExpressionsAfterCompilation) + .value("NoLoweredNamesBeforeCompilation", NvrtcErrorCode::NoLoweredNamesBeforeCompilation) + .value("NameExpressionNotValid", NvrtcErrorCode::NameExpressionNotValid) + .value("InternalError", NvrtcErrorCode::InternalError) + .value("NotLoaded", NvrtcErrorCode::NotLoaded) + .value("PtxLoadFailed", NvrtcErrorCode::PtxLoadFailed) + .value("FunctionNotFound", NvrtcErrorCode::FunctionNotFound) + .value("LaunchFailed", NvrtcErrorCode::LaunchFailed) + .export_values(); + + // Create custom NvrtcError exception type with code and log attributes + NvrtcErrorType = PyErr_NewExceptionWithDoc( + "_pygpukit_native.NvrtcError", + "NVRTC JIT compilation error with structured error information.\n\n" + "Attributes:\n" + " code (NvrtcErrorCode): Structured error code\n" + " compilation_log (str): NVRTC compiler output (if available)", + PyExc_RuntimeError, + nullptr + ); + m.attr("NvrtcError") = py::handle(NvrtcErrorType); + + // Register exception translator + py::register_exception_translator([](std::exception_ptr p) { + try { + if (p) std::rethrow_exception(p); + } catch (const NvrtcError& e) { + // Create exception with attributes + PyObject* exc = PyObject_CallFunction(NvrtcErrorType, "s", e.what()); + if (exc) { + PyObject_SetAttrString(exc, "code", py::cast(e.code()).ptr()); + PyObject_SetAttrString(exc, "compilation_log", py::cast(e.log()).ptr()); + PyErr_SetObject(NvrtcErrorType, exc); + Py_DECREF(exc); + } else { + PyErr_SetString(NvrtcErrorType, e.what()); + } + } + }); + // CompiledPTX struct py::class_(m, "CompiledPTX") .def_readonly("ptx", &CompiledPTX::ptx) diff --git a/native/core/device.cpp b/native/core/device.cpp index fcb3607..6011150 100644 --- a/native/core/device.cpp +++ b/native/core/device.cpp @@ -137,4 +137,99 @@ void validate_compute_capability(int device_id) { } } +std::string get_recommended_arch(int device_id) { + int sm = get_sm_version(device_id); + int driver_version = get_driver_version(); + + // Driver version is MAJOR*1000 + MINOR*10 + // e.g., CUDA 12.4 = 12040, CUDA 11.8 = 11080 + + // Clamp SM to what the driver supports + // CUDA 12.x supports SM 90 (Hopper) + // CUDA 11.8+ supports SM 90 + // CUDA 11.1-11.7 supports SM 86 + // CUDA 11.0 supports SM 80 + int max_supported_sm = 80; + + if (driver_version >= 12000) { + max_supported_sm = 90; // Hopper + } else if (driver_version >= 11080) { + max_supported_sm = 90; // 11.8 added SM 90 + } else if (driver_version >= 11010) { + max_supported_sm = 86; // 11.1 added SM 86 + } else { + max_supported_sm = 80; // SM 80 baseline + } + + // Use the minimum of actual SM and max supported + int target_sm = std::min(sm, max_supported_sm); + + // Ensure minimum SM 80 for PyGPUkit + if (target_sm < 80) { + target_sm = 80; + } + + return "sm_" + std::to_string(target_sm); +} + +std::vector get_fallback_archs(int device_id) { + int sm = get_sm_version(device_id); + std::vector archs; + + // Start with the actual SM, then add fallbacks + // Prefer SM versions, then compute versions (PTX only) + + // Add SM versions from current down to 80 + for (int target = sm; target >= 80; target -= (target > 86 ? 4 : 6)) { + archs.push_back("sm_" + std::to_string(target)); + // Add specific versions + if (target == 90) { + // After 90, try 89 (Ada), then 86, then 80 + archs.push_back("sm_89"); + } + if (target == 89 || target == 90) { + archs.push_back("sm_86"); + } + } + + // Finally add compute_80 as ultimate fallback (PTX only, JIT compiled by driver) + if (archs.empty() || archs.back() != "sm_80") { + archs.push_back("sm_80"); + } + archs.push_back("compute_80"); + + return archs; +} + +bool is_arch_supported(const std::string& arch) { + int driver_version = get_driver_version(); + + // Parse SM version from arch string + int sm_version = 0; + if (arch.find("sm_") == 0 || arch.find("compute_") == 0) { + size_t pos = arch.find('_'); + if (pos != std::string::npos) { + try { + sm_version = std::stoi(arch.substr(pos + 1)); + } catch (...) { + return false; + } + } + } else { + return false; + } + + // Check if driver supports this SM version + int max_sm = 80; + if (driver_version >= 12000) { + max_sm = 90; + } else if (driver_version >= 11080) { + max_sm = 90; + } else if (driver_version >= 11010) { + max_sm = 86; + } + + return sm_version <= max_sm && sm_version >= 80; +} + } // namespace pygpukit diff --git a/native/core/device.hpp b/native/core/device.hpp index 0352a53..847dd99 100644 --- a/native/core/device.hpp +++ b/native/core/device.hpp @@ -2,6 +2,7 @@ #include #include +#include namespace pygpukit { @@ -47,4 +48,16 @@ void validate_compute_capability(int device_id = 0); // Get SM version as integer (e.g., 86 for SM 8.6) int get_sm_version(int device_id = 0); +// Get recommended -arch option for JIT compilation (e.g., "sm_86") +// Based on current GPU's compute capability +std::string get_recommended_arch(int device_id = 0); + +// Get fallback -arch options for older drivers (in order of preference) +// Returns list like ["sm_80", "compute_80"] for fallback +std::vector get_fallback_archs(int device_id = 0); + +// Check if driver supports a given PTX architecture +// arch should be like "sm_86" or "compute_80" +bool is_arch_supported(const std::string& arch); + } // namespace pygpukit diff --git a/native/core/types.hpp b/native/core/types.hpp index ebeb5e2..76271c0 100644 --- a/native/core/types.hpp +++ b/native/core/types.hpp @@ -41,6 +41,51 @@ inline std::string dtype_name(DataType dtype) { using DevicePtr = void*; // Error handling + +// NVRTC error codes (matches nvrtcResult + custom codes) +enum class NvrtcErrorCode { + Success = 0, + OutOfMemory = 1, + ProgramCreationFailure = 2, + InvalidInput = 3, + InvalidProgram = 4, + InvalidOption = 5, + Compilation = 6, + BuiltinOperationFailure = 7, + NoNameExpressionsAfterCompilation = 8, + NoLoweredNamesBeforeCompilation = 9, + NameExpressionNotValid = 10, + InternalError = 11, + // Custom error codes (1000+) + NotLoaded = 1000, // NVRTC DLL not loaded + PtxLoadFailed = 1001, // cuModuleLoadData failed + FunctionNotFound = 1002, // cuModuleGetFunction failed + LaunchFailed = 1003, // cuLaunchKernel failed +}; + +// Get string name for error code +inline const char* nvrtc_error_name(NvrtcErrorCode code) { + switch (code) { + case NvrtcErrorCode::Success: return "Success"; + case NvrtcErrorCode::OutOfMemory: return "OutOfMemory"; + case NvrtcErrorCode::ProgramCreationFailure: return "ProgramCreationFailure"; + case NvrtcErrorCode::InvalidInput: return "InvalidInput"; + case NvrtcErrorCode::InvalidProgram: return "InvalidProgram"; + case NvrtcErrorCode::InvalidOption: return "InvalidOption"; + case NvrtcErrorCode::Compilation: return "Compilation"; + case NvrtcErrorCode::BuiltinOperationFailure: return "BuiltinOperationFailure"; + case NvrtcErrorCode::NoNameExpressionsAfterCompilation: return "NoNameExpressionsAfterCompilation"; + case NvrtcErrorCode::NoLoweredNamesBeforeCompilation: return "NoLoweredNamesBeforeCompilation"; + case NvrtcErrorCode::NameExpressionNotValid: return "NameExpressionNotValid"; + case NvrtcErrorCode::InternalError: return "InternalError"; + case NvrtcErrorCode::NotLoaded: return "NotLoaded"; + case NvrtcErrorCode::PtxLoadFailed: return "PtxLoadFailed"; + case NvrtcErrorCode::FunctionNotFound: return "FunctionNotFound"; + case NvrtcErrorCode::LaunchFailed: return "LaunchFailed"; + default: return "Unknown"; + } +} + class CudaError : public std::runtime_error { public: explicit CudaError(const std::string& msg) : std::runtime_error(msg) {} @@ -48,7 +93,27 @@ class CudaError : public std::runtime_error { class NvrtcError : public std::runtime_error { public: - explicit NvrtcError(const std::string& msg) : std::runtime_error(msg) {} + explicit NvrtcError(const std::string& msg) + : std::runtime_error(msg) + , code_(NvrtcErrorCode::InternalError) + , log_() {} + + NvrtcError(const std::string& msg, NvrtcErrorCode code) + : std::runtime_error(msg) + , code_(code) + , log_() {} + + NvrtcError(const std::string& msg, NvrtcErrorCode code, const std::string& log) + : std::runtime_error(msg) + , code_(code) + , log_(log) {} + + NvrtcErrorCode code() const { return code_; } + const std::string& log() const { return log_; } + +private: + NvrtcErrorCode code_; + std::string log_; }; } // namespace pygpukit diff --git a/native/jit/compiler.cpp b/native/jit/compiler.cpp index d250dc5..103f71c 100644 --- a/native/jit/compiler.cpp +++ b/native/jit/compiler.cpp @@ -6,9 +6,17 @@ namespace pygpukit { namespace { +// Convert nvrtc::Result to NvrtcErrorCode +NvrtcErrorCode to_error_code(nvrtc::Result result) { + return static_cast(static_cast(result)); +} + void check_nvrtc_error(nvrtc::Result result, const char* msg) { if (result != nvrtc::Result::Success) { - throw NvrtcError(std::string(msg) + ": " + nvrtc::get_error_string(result)); + throw NvrtcError( + std::string(msg) + ": " + nvrtc::get_error_string(result), + to_error_code(result) + ); } } @@ -17,7 +25,8 @@ void ensure_nvrtc_available() { throw NvrtcError( "NVRTC is not available. JIT compilation of custom kernels requires NVRTC. " "Pre-compiled GPU operations (matmul, add, mul) work without NVRTC. " - "For custom kernels, see: https://developer.nvidia.com/cuda-downloads" + "For custom kernels, see: https://developer.nvidia.com/cuda-downloads", + NvrtcErrorCode::NotLoaded ); } } @@ -76,7 +85,11 @@ CompiledPTX compile_to_ptx( if (result != nvrtc::Result::Success) { nvrtc::destroy_program(&prog); - throw NvrtcError("Compilation failed: " + log); + throw NvrtcError( + "Compilation failed: " + log, + NvrtcErrorCode::Compilation, + log + ); } // Get PTX diff --git a/native/jit/kernel.cpp b/native/jit/kernel.cpp index ccdc9d7..1c90092 100644 --- a/native/jit/kernel.cpp +++ b/native/jit/kernel.cpp @@ -71,11 +71,27 @@ void JITKernel::compile(const std::vector& options) { // Load module from PTX CUresult result = cuModuleLoadData(&module_, ptx_.c_str()); - check_cuda_driver_error(result, "Failed to load module from PTX"); + if (result != CUDA_SUCCESS) { + const char* error_str; + cuGetErrorString(result, &error_str); + throw NvrtcError( + std::string("Failed to load module from PTX: ") + (error_str ? error_str : "unknown error"), + NvrtcErrorCode::PtxLoadFailed + ); + } // Get function handle result = cuModuleGetFunction(&function_, module_, func_name_.c_str()); - check_cuda_driver_error(result, "Failed to get function from module"); + if (result != CUDA_SUCCESS) { + const char* error_str; + cuGetErrorString(result, &error_str); + cuModuleUnload(module_); + module_ = nullptr; + throw NvrtcError( + std::string("Function '") + func_name_ + "' not found in module: " + (error_str ? error_str : "unknown error"), + NvrtcErrorCode::FunctionNotFound + ); + } } void JITKernel::cleanup() { diff --git a/src/pygpukit/__init__.py b/src/pygpukit/__init__.py index e3370bf..b10d820 100644 --- a/src/pygpukit/__init__.py +++ b/src/pygpukit/__init__.py @@ -15,10 +15,17 @@ from pygpukit.core.stream import Stream, StreamManager, default_stream from pygpukit.jit.compiler import ( JITKernel, + NvrtcError, + NvrtcErrorCode, + check_driver_compatibility, + get_driver_requirements, get_nvrtc_path, get_nvrtc_version, + get_warmup_error, is_nvrtc_available, + is_warmup_done, jit, + warmup, ) from pygpukit.ops.basic import add, matmul, mul @@ -60,9 +67,16 @@ # JIT "jit", "JITKernel", + "NvrtcError", + "NvrtcErrorCode", "is_nvrtc_available", "get_nvrtc_version", "get_nvrtc_path", + "warmup", + "is_warmup_done", + "get_warmup_error", + "get_driver_requirements", + "check_driver_compatibility", # Operations "add", "mul", diff --git a/src/pygpukit/jit/__init__.py b/src/pygpukit/jit/__init__.py index e14325b..6d7a9d9 100644 --- a/src/pygpukit/jit/__init__.py +++ b/src/pygpukit/jit/__init__.py @@ -1,5 +1,25 @@ """JIT compilation module for PyGPUkit.""" -from pygpukit.jit.compiler import JITKernel, jit +from pygpukit.jit.compiler import ( + JITKernel, + NvrtcError, + NvrtcErrorCode, + check_driver_compatibility, + get_driver_requirements, + get_warmup_error, + is_warmup_done, + jit, + warmup, +) -__all__ = ["jit", "JITKernel"] +__all__ = [ + "jit", + "JITKernel", + "NvrtcError", + "NvrtcErrorCode", + "warmup", + "is_warmup_done", + "get_warmup_error", + "get_driver_requirements", + "check_driver_compatibility", +] diff --git a/src/pygpukit/jit/compiler.py b/src/pygpukit/jit/compiler.py index e8aba5b..170f52f 100644 --- a/src/pygpukit/jit/compiler.py +++ b/src/pygpukit/jit/compiler.py @@ -4,7 +4,7 @@ NVRTC is optional - use `is_nvrtc_available()` to check availability. If NVRTC is not available: -- JIT compilation will raise RuntimeError +- JIT compilation will raise NvrtcError - Pre-compiled kernels (matmul, add, etc.) will still work via the native backend - CPU simulation mode will continue to work """ @@ -13,9 +13,88 @@ import hashlib import re +from enum import IntEnum from typing import Any +class NvrtcErrorCode(IntEnum): + """NVRTC error codes for structured error handling. + + These codes map directly to NVRTC's nvrtcResult enum plus custom codes. + """ + + Success = 0 + OutOfMemory = 1 + ProgramCreationFailure = 2 + InvalidInput = 3 + InvalidProgram = 4 + InvalidOption = 5 + Compilation = 6 + BuiltinOperationFailure = 7 + NoNameExpressionsAfterCompilation = 8 + NoLoweredNamesBeforeCompilation = 9 + NameExpressionNotValid = 10 + InternalError = 11 + # Custom error codes (1000+) + NotLoaded = 1000 # NVRTC DLL not loaded + PtxLoadFailed = 1001 # cuModuleLoadData failed + FunctionNotFound = 1002 # cuModuleGetFunction failed + LaunchFailed = 1003 # cuLaunchKernel failed + + +class NvrtcError(RuntimeError): + """NVRTC JIT compilation error with structured information. + + Attributes: + code: Structured error code (NvrtcErrorCode) + compilation_log: NVRTC compiler output (if available) + + Example: + >>> try: + ... kernel = pygpukit.jit(bad_source, "my_kernel") + ... except pygpukit.NvrtcError as e: + ... print(f"Error code: {e.code.name}") + ... if e.compilation_log: + ... print(f"Compiler log: {e.compilation_log}") + """ + + def __init__( + self, + message: str, + code: NvrtcErrorCode | int = NvrtcErrorCode.InternalError, + compilation_log: str = "", + ) -> None: + super().__init__(message) + self._code = NvrtcErrorCode(code) if isinstance(code, int) else code + self._compilation_log = compilation_log + + @property + def code(self) -> NvrtcErrorCode: + """Return the structured error code.""" + return self._code + + @property + def compilation_log(self) -> str: + """Return the NVRTC compiler output log.""" + return self._compilation_log + + def __str__(self) -> str: + base = super().__str__() + return f"[{self._code.name}] {base}" + + +def _wrap_native_nvrtc_error(exc: Exception) -> NvrtcError: + """Convert native NvrtcError to Python NvrtcError.""" + code = getattr(exc, "code", NvrtcErrorCode.InternalError) + log = getattr(exc, "compilation_log", "") + + # Convert native enum to Python enum if needed + if hasattr(code, "value"): + code = code.value + + return NvrtcError(str(exc), code, log) + + def is_nvrtc_available() -> bool: """Check if NVRTC JIT compiler is available. @@ -101,6 +180,93 @@ def get_nvrtc_version() -> tuple[int, int] | None: return None +# ============================================================================ +# Driver Version Requirements +# ============================================================================ + +# Minimum supported CUDA driver version (CUDA 11.0) +# Version format: MAJOR*1000 + MINOR*10 (e.g., 11.0 = 11000) +MIN_DRIVER_VERSION = 11000 +MIN_DRIVER_VERSION_STR = "11.0" + +# Minimum required GPU architecture (Ampere) +MIN_SM_VERSION = 80 +MIN_SM_VERSION_STR = "SM 8.0 (Ampere)" + + +def get_driver_requirements() -> dict[str, str]: + """Get driver and hardware requirements for PyGPUkit. + + Returns: + Dictionary with minimum requirements and recommendations. + + Example: + >>> import pygpukit as gp + >>> reqs = gp.get_driver_requirements() + >>> print(reqs['min_driver_version']) + '11.0' + """ + return { + "min_driver_version": MIN_DRIVER_VERSION_STR, + "min_gpu_architecture": MIN_SM_VERSION_STR, + "recommended_driver_version": "12.0+", + "recommended_gpu": "RTX 30xx/40xx series or newer", + "supported_architectures": "Ampere (SM 80-86), Ada (SM 89), Hopper (SM 90)", + "notes": ( + "PyGPUkit requires Ampere or newer GPUs. " + "Older architectures (Pascal, Turing) are not supported. " + "For best performance, use the latest NVIDIA driver." + ), + } + + +def check_driver_compatibility() -> tuple[bool, str]: + """Check if current driver meets minimum requirements. + + Returns: + Tuple of (is_compatible, message) where is_compatible is True if the + driver meets requirements, and message contains details. + + Example: + >>> import pygpukit as gp + >>> ok, msg = gp.check_driver_compatibility() + >>> if not ok: + ... print(f"Warning: {msg}") + """ + try: + from pygpukit.core.backend import get_native_module, has_native_module + + if not has_native_module(): + return False, "Native module not available (CPU simulation mode)" + + native = get_native_module() + + if not native.is_cuda_available(): + return False, "CUDA is not available" + + driver_version = native.get_driver_version() + if driver_version < MIN_DRIVER_VERSION: + driver_str = f"{driver_version // 1000}.{(driver_version % 1000) // 10}" + return False, ( + f"Driver version {driver_str} is below minimum required " + f"{MIN_DRIVER_VERSION_STR}. Please update your NVIDIA driver." + ) + + sm_version = native.get_sm_version() + if sm_version < MIN_SM_VERSION: + return False, ( + f"GPU SM {sm_version // 10}.{sm_version % 10} is below minimum " + f"required {MIN_SM_VERSION_STR}. PyGPUkit requires Ampere or newer." + ) + + # All checks passed + driver_str = f"{driver_version // 1000}.{(driver_version % 1000) // 10}" + return True, f"Driver {driver_str}, SM {sm_version // 10}.{sm_version % 10}" + + except Exception as e: + return False, f"Error checking compatibility: {e}" + + class JITKernel: """A JIT-compiled CUDA kernel. @@ -165,12 +331,28 @@ def _compile(self) -> None: self._is_compiled = True self._ptx = f"// Simulated PTX for {self._name}" + # Retry configuration for transient errors + _MAX_RETRIES = 3 + _RETRY_DELAY_MS = 100 # Base delay in milliseconds + _RETRYABLE_ERRORS = { + NvrtcErrorCode.OutOfMemory, + NvrtcErrorCode.InternalError, + NvrtcErrorCode.BuiltinOperationFailure, + } + def _compile_native(self) -> None: """Compile using native C++ module (NVRTC). + Automatically selects appropriate -arch option based on GPU and driver. + Falls back to lower architectures if PTX loading fails. + Retries on transient errors (out of memory, internal errors). + Raises: - RuntimeError: If NVRTC is not available with helpful installation instructions. + NvrtcError: If NVRTC is not available or compilation fails. """ + import time + import warnings + from pygpukit.core.backend import _find_nvrtc_dll, get_native_module native = get_native_module() @@ -196,12 +378,164 @@ def _compile_native(self) -> None: " https://developer.nvidia.com/cuda-downloads\n\n" "Check availability: pygpukit.is_nvrtc_available()" ) - raise RuntimeError(msg) - - # Use native JITKernel which handles NVRTC compilation - self._kernel = native.JITKernel(self._source, self._name, self._options) - self._ptx = self._kernel.ptx - self._is_compiled = self._kernel.is_compiled + raise NvrtcError(msg, NvrtcErrorCode.NotLoaded) + + # Prepare options with auto arch selection + options = self._prepare_compile_options(native) + + # Try compilation with fallback on PTX load failure + fallback_archs = self._get_fallback_archs(native) + last_error: Exception | None = None + arch_used: str | None = None + + for arch_attempt, arch in enumerate(fallback_archs): + current_options = self._replace_arch_option(options, arch) + + # Retry loop for transient errors + for retry in range(self._MAX_RETRIES): + try: + self._kernel = native.JITKernel( + self._source, self._name, current_options + ) + self._ptx = self._kernel.ptx + self._is_compiled = self._kernel.is_compiled + arch_used = arch + + # Warn if fallback was used + if arch_attempt > 0: + warnings.warn( + f"JIT compilation succeeded using fallback architecture " + f"'{arch}'. Original architecture failed. Consider updating " + f"your NVIDIA driver for better compatibility.", + UserWarning, + stacklevel=4, + ) + return # Success + except Exception as e: + last_error = e + err_code = self._get_error_code(e) + err_msg = str(e) + + # Check if this is a retryable transient error + if err_code in self._RETRYABLE_ERRORS and retry < self._MAX_RETRIES - 1: + # Exponential backoff + delay = self._RETRY_DELAY_MS * (2**retry) / 1000.0 + time.sleep(delay) + continue + + # Compilation errors should not be retried + if "Compilation failed" in err_msg: + break + + # PTX load failure - try next fallback arch + is_ptx_load_error = ( + "PTX" in err_msg + or "module" in err_msg.lower() + or "CUDA_ERROR" in err_msg + or err_code == NvrtcErrorCode.PtxLoadFailed + ) + if is_ptx_load_error and arch_attempt < len(fallback_archs) - 1: + break # Try next arch + + # Other error - stop retrying this arch + break + + # If we reach here without returning, try next arch + # (unless it was a compilation error) + if last_error and "Compilation failed" in str(last_error): + break # Don't try other archs for syntax errors + + # All attempts failed + if last_error is not None: + if hasattr(last_error, "code") and hasattr(last_error, "compilation_log"): + raise _wrap_native_nvrtc_error(last_error) from None + msg = str(last_error) + if "Compilation failed" in msg: + raise NvrtcError(msg, NvrtcErrorCode.Compilation) from None + elif "not found in module" in msg or "Function" in msg: + raise NvrtcError(msg, NvrtcErrorCode.FunctionNotFound) from None + elif "PTX" in msg or "module" in msg.lower(): + raise NvrtcError(msg, NvrtcErrorCode.PtxLoadFailed) from None + else: + raise NvrtcError(msg, NvrtcErrorCode.InternalError) from None + + def _get_error_code(self, exc: Exception) -> NvrtcErrorCode: + """Extract error code from exception.""" + if hasattr(exc, "code"): + code = exc.code + if hasattr(code, "value"): + return NvrtcErrorCode(code.value) + elif isinstance(code, int): + try: + return NvrtcErrorCode(code) + except ValueError: + return NvrtcErrorCode.InternalError + return NvrtcErrorCode.InternalError + + def _prepare_compile_options(self, native: Any) -> list[str]: + """Prepare compilation options with auto arch selection.""" + options = list(self._options) + + # Check if user already specified -arch + has_arch = any( + opt.startswith("-arch=") or opt.startswith("--gpu-architecture=") + for opt in options + ) + + if not has_arch: + # Auto-select arch based on current GPU + try: + recommended_arch = native.get_recommended_arch() + options.append(f"-arch={recommended_arch}") + except Exception: + # Fallback to sm_80 (minimum supported) + options.append("-arch=sm_80") + + return options + + def _get_fallback_archs(self, native: Any) -> list[str]: + """Get list of architectures to try (primary + fallbacks).""" + # Check if user specified arch + user_arch = None + for opt in self._options: + if opt.startswith("-arch="): + user_arch = opt.split("=", 1)[1] + break + elif opt.startswith("--gpu-architecture="): + user_arch = opt.split("=", 1)[1] + break + + if user_arch: + # User specified arch - use it as primary, add fallbacks + archs = [user_arch] + try: + fallbacks = native.get_fallback_archs() + for fb in fallbacks: + if fb not in archs: + archs.append(fb) + except Exception: + archs.extend(["sm_86", "sm_80", "compute_80"]) + return archs + else: + # Auto-select - use recommended arch as primary + try: + return native.get_fallback_archs() + except Exception: + return ["sm_86", "sm_80", "compute_80"] + + def _replace_arch_option(self, options: list[str], new_arch: str) -> list[str]: + """Replace -arch option with new architecture.""" + result = [] + arch_found = False + for opt in options: + if opt.startswith("-arch=") or opt.startswith("--gpu-architecture="): + result.append(f"-arch={new_arch}") + arch_found = True + else: + result.append(opt) + if not arch_found: + result.append(f"-arch={new_arch}") + return result @property def source(self) -> str: @@ -298,3 +632,148 @@ def jit( >>> kernel(x, 0.5, n) """ return JITKernel(source, func, options, block_size) + + +# ============================================================================ +# JIT Warmup System +# ============================================================================ + +import threading +from typing import Callable + +# Global warmup state +_warmup_lock = threading.Lock() +_warmup_done = False +_warmup_thread: threading.Thread | None = None +_warmup_error: Exception | None = None + +# Warmup test kernel +_WARMUP_KERNEL_SOURCE = ''' +extern "C" __global__ void _pygpukit_warmup_kernel(float* x, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) x[idx] = x[idx]; +} +''' + + +def warmup( + background: bool = False, + callback: Callable[[], None] | None = None, +) -> bool: + """Warm up the JIT compiler. + + This function pre-initializes the NVRTC JIT compiler by compiling a simple + test kernel. This ensures that subsequent JIT compilations are faster as + the compiler is already loaded and initialized. + + Args: + background: If True, run warmup in a background thread. + callback: Optional callback to invoke when warmup completes (only used + when background=True). + + Returns: + True if warmup succeeded (or was already done), False if NVRTC is + not available. When background=True, returns True immediately and + warmup continues in background. + + Example: + >>> import pygpukit as gp + >>> # Synchronous warmup + >>> gp.warmup() + True + >>> # Background warmup + >>> gp.warmup(background=True, callback=lambda: print("Ready!")) + True + """ + global _warmup_done, _warmup_thread, _warmup_error + + with _warmup_lock: + if _warmup_done: + return _warmup_error is None + + if _warmup_thread is not None and _warmup_thread.is_alive(): + # Warmup already in progress + if not background: + _warmup_thread.join() + return True + + if background: + _warmup_thread = threading.Thread( + target=_do_warmup, + args=(callback,), + daemon=True, + ) + _warmup_thread.start() + return True + else: + return _do_warmup(callback) + + +def _do_warmup(callback: Callable[[], None] | None = None) -> bool: + """Perform the actual warmup.""" + global _warmup_done, _warmup_error + + try: + # Check if NVRTC is available + if not is_nvrtc_available(): + _warmup_error = NvrtcError( + "NVRTC not available", NvrtcErrorCode.NotLoaded + ) + _warmup_done = True + return False + + # Compile warmup kernel + try: + _ = JITKernel( + _WARMUP_KERNEL_SOURCE, + "_pygpukit_warmup_kernel", + options=[], # Use default arch + ) + except NvrtcError as e: + _warmup_error = e + _warmup_done = True + return False + + _warmup_done = True + _warmup_error = None + + if callback is not None: + try: + callback() + except Exception: + pass # Ignore callback errors + + return True + except Exception as e: + _warmup_error = e + _warmup_done = True + return False + + +def is_warmup_done() -> bool: + """Check if JIT warmup has completed. + + Returns: + True if warmup has completed (successfully or with error), False if + warmup is in progress or has not started. + + Example: + >>> import pygpukit as gp + >>> gp.warmup(background=True) + True + >>> # ... do other initialization ... + >>> while not gp.is_warmup_done(): + ... time.sleep(0.01) + >>> print("JIT compiler ready!") + """ + return _warmup_done + + +def get_warmup_error() -> Exception | None: + """Get the warmup error if warmup failed. + + Returns: + The exception that caused warmup to fail, or None if warmup succeeded + or has not completed. + """ + return _warmup_error From b7c5224eea3e6105fe92fa12b44f5b706e959392 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Sun, 14 Dec 2025 23:39:43 +0900 Subject: [PATCH 02/24] feat(cache): implement persistent kernel cache with LRU eviction (#54) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add disk-based PTX cache with architecture fingerprinting: Rust implementation (pygpukit-core): - PersistentCache: disk-based PTX storage with JSON index - ArchFingerprint: GPU characteristics for cache key (SM version, memory, driver) - LRU eviction by entry count and total size - TTL-based expiration with auto cleanup - Serde serialization for persistence Python bindings (pygpukit-python): - ArchFingerprint: GPU architecture fingerprint - PersistentCacheConfig: cache directory, size limits, TTL - PersistentCache: insert, get, remove, cleanup, stats - PersistentEntry: cached PTX with metadata - PersistentCacheStats: hit rate, evictions, errors Features: - Architecture-aware cache keys (SM version + driver version) - Automatic directory creation and index management - Size-based and count-based eviction policies - Cross-session cache persistence Tests: 5 new tests for persistent_cache (119 total) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- rust/Cargo.lock | 218 ++++- rust/Cargo.toml | 3 + rust/pygpukit-core/Cargo.toml | 3 + rust/pygpukit-core/src/dispatch/mod.rs | 5 + .../src/dispatch/persistent_cache.rs | 823 ++++++++++++++++++ rust/pygpukit-python/Cargo.toml | 1 + rust/pygpukit-python/src/dispatch.rs | 433 ++++++++- 7 files changed, 1483 insertions(+), 3 deletions(-) create mode 100644 rust/pygpukit-core/src/dispatch/persistent_cache.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 0ed206e..da2c16a 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -26,12 +26,44 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "dirs" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys", +] + [[package]] name = "equivalent" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "getrandom" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "getrandom" version = "0.3.4" @@ -75,6 +107,12 @@ dependencies = [ "rustversion", ] +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + [[package]] name = "js-sys" version = "0.3.83" @@ -91,6 +129,16 @@ version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" +[[package]] +name = "libredox" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb" +dependencies = [ + "bitflags", + "libc", +] + [[package]] name = "lock_api" version = "0.4.14" @@ -110,6 +158,12 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + [[package]] name = "memoffset" version = "0.9.1" @@ -182,6 +236,12 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + [[package]] name = "parking_lot" version = "0.12.5" @@ -233,14 +293,18 @@ dependencies = [ name = "pygpukit-core" version = "0.2.0" dependencies = [ + "dirs", "indexmap", "parking_lot", + "serde", + "serde_json", ] [[package]] name = "pygpukit-python" version = "0.2.0" dependencies = [ + "dirs", "numpy", "parking_lot", "pygpukit-core", @@ -341,6 +405,17 @@ dependencies = [ "bitflags", ] +[[package]] +name = "redox_users" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" +dependencies = [ + "getrandom 0.2.16", + "libredox", + "thiserror", +] + [[package]] name = "rustc-hash" version = "2.1.1" @@ -353,12 +428,61 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.145" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", + "serde_core", +] + [[package]] name = "smallvec" version = "1.15.1" @@ -382,6 +506,26 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicode-ident" version = "1.0.22" @@ -400,11 +544,17 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" dependencies = [ - "getrandom", + "getrandom 0.3.4", "js-sys", "wasm-bindgen", ] +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + [[package]] name = "wasip2" version = "1.0.1+wasi-0.2.4" @@ -465,6 +615,72 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + [[package]] name = "wit-bindgen" version = "0.46.0" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 72e2506..2802dfd 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -14,3 +14,6 @@ numpy = "0.23" parking_lot = "0.12" indexmap = "2.7" uuid = { version = "1.11", features = ["v4"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +dirs = "5.0" diff --git a/rust/pygpukit-core/Cargo.toml b/rust/pygpukit-core/Cargo.toml index 5da9090..3f154ca 100644 --- a/rust/pygpukit-core/Cargo.toml +++ b/rust/pygpukit-core/Cargo.toml @@ -8,3 +8,6 @@ description = "Core Rust implementation for PyGPUkit memory pool and scheduler" [dependencies] parking_lot.workspace = true indexmap.workspace = true +serde.workspace = true +serde_json.workspace = true +dirs.workspace = true diff --git a/rust/pygpukit-core/src/dispatch/mod.rs b/rust/pygpukit-core/src/dispatch/mod.rs index 28231af..efceecd 100644 --- a/rust/pygpukit-core/src/dispatch/mod.rs +++ b/rust/pygpukit-core/src/dispatch/mod.rs @@ -15,6 +15,7 @@ mod controller; mod pacing; mod slicing; mod cache; +mod persistent_cache; pub use controller::{KernelDispatcher, KernelLaunchRequest, KernelState, DispatchStats, LaunchConfig}; pub use pacing::{ @@ -26,3 +27,7 @@ pub use slicing::{ pub use cache::{ KernelCache, CacheConfig, CachedKernel, CompileOptions, CacheStats, }; +pub use persistent_cache::{ + PersistentCache, PersistentCacheConfig, PersistentCacheStats, + PersistentEntry, ArchFingerprint, CacheIndex, CacheError, +}; diff --git a/rust/pygpukit-core/src/dispatch/persistent_cache.rs b/rust/pygpukit-core/src/dispatch/persistent_cache.rs new file mode 100644 index 0000000..d576e01 --- /dev/null +++ b/rust/pygpukit-core/src/dispatch/persistent_cache.rs @@ -0,0 +1,823 @@ +//! Persistent Kernel Cache +//! +//! Extends the in-memory kernel cache with disk persistence. +//! Compiled PTX is saved to `~/.pygpukit/kernel_cache/` for reuse across sessions. + +use std::collections::HashMap; +use std::fs::{self, File}; +use std::io::{BufReader, BufWriter}; +use std::path::PathBuf; +use std::hash::{Hash, Hasher}; +use std::collections::hash_map::DefaultHasher; + +use serde::{Deserialize, Serialize}; + +/// GPU architecture fingerprint for cache key generation +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct ArchFingerprint { + /// SM version (e.g., 86 for SM 8.6) + pub sm_version: u32, + /// Total global memory in bytes + pub global_memory: u64, + /// Shared memory per SM in bytes + pub shared_memory_per_sm: u32, + /// Max registers per block + pub max_registers_per_block: u32, + /// L2 cache size in bytes + pub l2_cache_size: u32, + /// CUDA driver version (MAJOR*1000 + MINOR*10) + pub driver_version: u32, +} + +impl ArchFingerprint { + /// Create a new architecture fingerprint + pub fn new( + sm_version: u32, + global_memory: u64, + shared_memory_per_sm: u32, + max_registers_per_block: u32, + l2_cache_size: u32, + driver_version: u32, + ) -> Self { + Self { + sm_version, + global_memory, + shared_memory_per_sm, + max_registers_per_block, + l2_cache_size, + driver_version, + } + } + + /// Compute hash of fingerprint + pub fn hash(&self) -> u64 { + let mut hasher = DefaultHasher::new(); + Hash::hash(self, &mut hasher); + hasher.finish() + } + + /// Check if this fingerprint is compatible with another + /// (same SM version and driver version are minimum requirements) + pub fn is_compatible(&self, other: &Self) -> bool { + self.sm_version == other.sm_version && self.driver_version == other.driver_version + } +} + +impl Default for ArchFingerprint { + fn default() -> Self { + Self { + sm_version: 80, + global_memory: 0, + shared_memory_per_sm: 0, + max_registers_per_block: 0, + l2_cache_size: 0, + driver_version: 11000, + } + } +} + +/// Persistent cache entry (stored on disk) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PersistentEntry { + /// Source code hash + pub source_hash: u64, + /// Kernel name + pub name: String, + /// Compile options hash + pub options_hash: u64, + /// Architecture fingerprint + pub arch_fingerprint: ArchFingerprint, + /// PTX code + pub ptx: String, + /// Creation timestamp (Unix epoch) + pub created_at: f64, + /// Last access timestamp + pub last_access: f64, + /// Access count + pub access_count: usize, +} + +impl PersistentEntry { + /// Create a new entry + pub fn new( + source_hash: u64, + name: String, + options_hash: u64, + arch_fingerprint: ArchFingerprint, + ptx: String, + ) -> Self { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0); + + Self { + source_hash, + name, + options_hash, + arch_fingerprint, + ptx, + created_at: now, + last_access: now, + access_count: 1, + } + } + + /// Touch to update access time + pub fn touch(&mut self) { + self.last_access = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0); + self.access_count += 1; + } + + /// Get PTX size in bytes + pub fn ptx_size(&self) -> usize { + self.ptx.len() + } +} + +/// Cache index (stored separately for quick lookup) +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct CacheIndex { + /// Version of the cache format + pub version: u32, + /// Architecture fingerprint this index was built for + pub arch_fingerprint: ArchFingerprint, + /// Map of cache key to filename + pub entries: HashMap, + /// Total size of all cached PTX + pub total_size: usize, + /// Last cleanup timestamp + pub last_cleanup: f64, +} + +impl CacheIndex { + /// Current cache format version + pub const CURRENT_VERSION: u32 = 1; + + /// Create a new index + pub fn new(arch_fingerprint: ArchFingerprint) -> Self { + Self { + version: Self::CURRENT_VERSION, + arch_fingerprint, + entries: HashMap::new(), + total_size: 0, + last_cleanup: 0.0, + } + } + + /// Check if index is compatible with current arch + pub fn is_compatible(&self, arch: &ArchFingerprint) -> bool { + self.version == Self::CURRENT_VERSION && self.arch_fingerprint.is_compatible(arch) + } +} + +/// Persistent cache configuration +#[derive(Debug, Clone)] +pub struct PersistentCacheConfig { + /// Cache directory path + pub cache_dir: PathBuf, + /// Maximum total cache size in bytes + pub max_size: usize, + /// Maximum number of entries + pub max_entries: usize, + /// Enable auto-cleanup on startup + pub auto_cleanup: bool, + /// Entry TTL in seconds (0 = infinite) + pub ttl_seconds: f64, +} + +impl Default for PersistentCacheConfig { + fn default() -> Self { + let cache_dir = dirs::home_dir() + .unwrap_or_else(|| PathBuf::from(".")) + .join(".pygpukit") + .join("kernel_cache"); + + Self { + cache_dir, + max_size: 512 * 1024 * 1024, // 512MB + max_entries: 4096, + auto_cleanup: true, + ttl_seconds: 0.0, // No TTL by default + } + } +} + +/// Cache statistics +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct PersistentCacheStats { + /// Number of entries + pub entries: usize, + /// Total size in bytes + pub total_size: usize, + /// Cache hits + pub hits: usize, + /// Cache misses + pub misses: usize, + /// Evictions + pub evictions: usize, + /// Load errors + pub load_errors: usize, + /// Save errors + pub save_errors: usize, +} + +impl PersistentCacheStats { + /// Calculate hit rate + pub fn hit_rate(&self) -> f64 { + let total = self.hits + self.misses; + if total > 0 { + self.hits as f64 / total as f64 + } else { + 0.0 + } + } +} + +/// Persistent kernel cache +pub struct PersistentCache { + config: PersistentCacheConfig, + arch_fingerprint: ArchFingerprint, + index: CacheIndex, + stats: PersistentCacheStats, + initialized: bool, +} + +impl PersistentCache { + /// Create a new persistent cache + pub fn new(config: PersistentCacheConfig, arch_fingerprint: ArchFingerprint) -> Self { + let index = CacheIndex::new(arch_fingerprint.clone()); + Self { + config, + arch_fingerprint, + index, + stats: PersistentCacheStats::default(), + initialized: false, + } + } + + /// Create with defaults + pub fn with_defaults(arch_fingerprint: ArchFingerprint) -> Self { + Self::new(PersistentCacheConfig::default(), arch_fingerprint) + } + + /// Initialize the cache (load index, create directories) + pub fn initialize(&mut self) -> Result<(), CacheError> { + if self.initialized { + return Ok(()); + } + + // Create cache directory + fs::create_dir_all(&self.config.cache_dir).map_err(|e| CacheError::Io(e.to_string()))?; + + // Load or create index + let index_path = self.index_path(); + if index_path.exists() { + match self.load_index() { + Ok(index) => { + if index.is_compatible(&self.arch_fingerprint) { + self.index = index; + self.stats.entries = self.index.entries.len(); + self.stats.total_size = self.index.total_size; + } else { + // Incompatible index - clear cache + self.clear()?; + } + } + Err(_) => { + // Corrupted index - clear cache + self.clear()?; + } + } + } + + // Auto cleanup if enabled + if self.config.auto_cleanup { + let _ = self.cleanup(); + } + + self.initialized = true; + Ok(()) + } + + /// Get cache directory path + pub fn cache_dir(&self) -> &PathBuf { + &self.config.cache_dir + } + + /// Get index file path + fn index_path(&self) -> PathBuf { + self.config.cache_dir.join("index.json") + } + + /// Get entry file path + fn entry_path(&self, key: u64) -> PathBuf { + self.config.cache_dir.join(format!("{:016x}.ptx.json", key)) + } + + /// Compute cache key + pub fn compute_key(source_hash: u64, name: &str, options_hash: u64, arch_hash: u64) -> u64 { + let mut hasher = DefaultHasher::new(); + source_hash.hash(&mut hasher); + name.hash(&mut hasher); + options_hash.hash(&mut hasher); + arch_hash.hash(&mut hasher); + hasher.finish() + } + + /// Hash source code + pub fn hash_source(source: &str) -> u64 { + let mut hasher = DefaultHasher::new(); + source.hash(&mut hasher); + hasher.finish() + } + + /// Hash compile options + pub fn hash_options(options: &[String]) -> u64 { + let mut hasher = DefaultHasher::new(); + for opt in options { + opt.hash(&mut hasher); + } + hasher.finish() + } + + /// Get cached entry + pub fn get(&mut self, key: u64) -> Result, CacheError> { + if !self.initialized { + self.initialize()?; + } + + if !self.index.entries.contains_key(&key) { + self.stats.misses += 1; + return Ok(None); + } + + // Load entry from disk + let entry_path = self.entry_path(key); + match self.load_entry(&entry_path) { + Ok(mut entry) => { + // Check TTL + if self.config.ttl_seconds > 0.0 { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0); + if now - entry.created_at > self.config.ttl_seconds { + // TTL expired - remove + let _ = self.remove(key); + self.stats.misses += 1; + return Ok(None); + } + } + + // Update access time + entry.touch(); + let _ = self.save_entry(key, &entry); + + self.stats.hits += 1; + Ok(Some(entry)) + } + Err(_) => { + // Remove corrupted entry + self.index.entries.remove(&key); + let _ = self.save_index(); + self.stats.misses += 1; + self.stats.load_errors += 1; + Ok(None) + } + } + } + + /// Insert entry + pub fn insert( + &mut self, + source: &str, + name: &str, + options: &[String], + ptx: String, + ) -> Result { + if !self.initialized { + self.initialize()?; + } + + let source_hash = Self::hash_source(source); + let options_hash = Self::hash_options(options); + let arch_hash = self.arch_fingerprint.hash(); + let key = Self::compute_key(source_hash, name, options_hash, arch_hash); + + // Check if already exists + if self.index.entries.contains_key(&key) { + // Update access time + if let Ok(Some(mut entry)) = self.get(key) { + entry.touch(); + let _ = self.save_entry(key, &entry); + } + return Ok(key); + } + + // Evict if necessary + let ptx_size = ptx.len(); + self.evict_if_needed(ptx_size)?; + + // Create entry + let entry = PersistentEntry::new( + source_hash, + name.to_string(), + options_hash, + self.arch_fingerprint.clone(), + ptx, + ); + + // Save entry + self.save_entry(key, &entry)?; + + // Update index + self.index.entries.insert(key, format!("{:016x}.ptx.json", key)); + self.index.total_size += ptx_size; + self.save_index()?; + + self.stats.entries = self.index.entries.len(); + self.stats.total_size = self.index.total_size; + + Ok(key) + } + + /// Remove entry + pub fn remove(&mut self, key: u64) -> Result { + if !self.index.entries.contains_key(&key) { + return Ok(false); + } + + // Get entry size before removing + let entry_path = self.entry_path(key); + let size = if let Ok(entry) = self.load_entry(&entry_path) { + entry.ptx_size() + } else { + 0 + }; + + // Remove file + let _ = fs::remove_file(&entry_path); + + // Update index + self.index.entries.remove(&key); + self.index.total_size = self.index.total_size.saturating_sub(size); + self.save_index()?; + + self.stats.entries = self.index.entries.len(); + self.stats.total_size = self.index.total_size; + + Ok(true) + } + + /// Evict entries if needed + fn evict_if_needed(&mut self, new_size: usize) -> Result<(), CacheError> { + // Evict by entry count + while self.index.entries.len() >= self.config.max_entries { + if !self.evict_lru()? { + break; // No more entries to evict + } + } + + // Evict by size + while self.index.total_size + new_size > self.config.max_size + && !self.index.entries.is_empty() + { + if !self.evict_lru()? { + break; // No more entries to evict + } + } + + Ok(()) + } + + /// Evict least recently used entry + fn evict_lru(&mut self) -> Result { + // Load all entries to find LRU + let mut lru_key: Option = None; + let mut lru_time = f64::MAX; + + // Collect keys to avoid borrow issues + let keys: Vec = self.index.entries.keys().copied().collect(); + + for key in keys { + let entry_path = self.entry_path(key); + if let Ok(entry) = self.load_entry(&entry_path) { + if entry.last_access < lru_time { + lru_time = entry.last_access; + lru_key = Some(key); + } + } else { + // If we can't load the entry, it's a candidate for removal (orphaned index entry) + lru_key = Some(key); + break; + } + } + + if let Some(key) = lru_key { + self.remove(key)?; + self.stats.evictions += 1; + Ok(true) + } else { + Ok(false) // No entry to evict + } + } + + /// Cleanup expired entries and orphaned files + pub fn cleanup(&mut self) -> Result { + if !self.initialized { + self.initialize()?; + } + + let mut removed = 0; + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0); + + // Collect expired entries + let mut to_remove = Vec::new(); + + if self.config.ttl_seconds > 0.0 { + for &key in self.index.entries.keys() { + let entry_path = self.entry_path(key); + if let Ok(entry) = self.load_entry(&entry_path) { + if now - entry.created_at > self.config.ttl_seconds { + to_remove.push(key); + } + } + } + } + + // Remove expired entries + for key in to_remove { + self.remove(key)?; + removed += 1; + } + + // Update last cleanup time + self.index.last_cleanup = now; + self.save_index()?; + + Ok(removed) + } + + /// Clear all cache + pub fn clear(&mut self) -> Result<(), CacheError> { + // Remove all entry files + if self.config.cache_dir.exists() { + for entry in fs::read_dir(&self.config.cache_dir) + .map_err(|e| CacheError::Io(e.to_string()))? + { + if let Ok(entry) = entry { + let path = entry.path(); + if path.extension().map_or(false, |ext| ext == "json") { + let _ = fs::remove_file(path); + } + } + } + } + + // Reset index + self.index = CacheIndex::new(self.arch_fingerprint.clone()); + self.stats = PersistentCacheStats::default(); + + Ok(()) + } + + /// Get statistics + pub fn stats(&self) -> &PersistentCacheStats { + &self.stats + } + + /// Get number of entries + pub fn len(&self) -> usize { + self.index.entries.len() + } + + /// Check if empty + pub fn is_empty(&self) -> bool { + self.index.entries.is_empty() + } + + /// Check if key exists + pub fn contains(&self, key: u64) -> bool { + self.index.entries.contains_key(&key) + } + + /// Load index from disk + fn load_index(&self) -> Result { + let file = File::open(self.index_path()).map_err(|e| CacheError::Io(e.to_string()))?; + let reader = BufReader::new(file); + serde_json::from_reader(reader).map_err(|e| CacheError::Serialization(e.to_string())) + } + + /// Save index to disk + fn save_index(&self) -> Result<(), CacheError> { + let file = + File::create(self.index_path()).map_err(|e| CacheError::Io(e.to_string()))?; + let writer = BufWriter::new(file); + serde_json::to_writer_pretty(writer, &self.index) + .map_err(|e| CacheError::Serialization(e.to_string())) + } + + /// Load entry from disk + fn load_entry(&self, path: &PathBuf) -> Result { + let file = File::open(path).map_err(|e| CacheError::Io(e.to_string()))?; + let reader = BufReader::new(file); + serde_json::from_reader(reader).map_err(|e| CacheError::Serialization(e.to_string())) + } + + /// Save entry to disk + fn save_entry(&mut self, key: u64, entry: &PersistentEntry) -> Result<(), CacheError> { + let path = self.entry_path(key); + let file = File::create(&path).map_err(|e| CacheError::Io(e.to_string()))?; + let writer = BufWriter::new(file); + serde_json::to_writer(writer, entry).map_err(|e| { + self.stats.save_errors += 1; + CacheError::Serialization(e.to_string()) + }) + } +} + +/// Cache error types +#[derive(Debug, Clone)] +pub enum CacheError { + /// I/O error + Io(String), + /// Serialization error + Serialization(String), + /// Not initialized + NotInitialized, +} + +impl std::fmt::Display for CacheError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CacheError::Io(s) => write!(f, "I/O error: {}", s), + CacheError::Serialization(s) => write!(f, "Serialization error: {}", s), + CacheError::NotInitialized => write!(f, "Cache not initialized"), + } + } +} + +impl std::error::Error for CacheError {} + +#[cfg(test)] +mod tests { + use super::*; + use std::env; + use std::sync::atomic::{AtomicU64, Ordering}; + + // Unique counter for test directories + static TEST_COUNTER: AtomicU64 = AtomicU64::new(0); + + fn test_config() -> PersistentCacheConfig { + // Generate unique directory for each test + let id = TEST_COUNTER.fetch_add(1, Ordering::Relaxed); + let thread_id = std::thread::current().id(); + let temp_dir = env::temp_dir().join(format!( + "pygpukit_test_cache_{:?}_{}", + thread_id, + id + )); + PersistentCacheConfig { + cache_dir: temp_dir, + max_size: 1024 * 1024, // 1MB + max_entries: 10, + auto_cleanup: false, + ttl_seconds: 0.0, + } + } + + fn test_arch() -> ArchFingerprint { + ArchFingerprint::new(86, 24 * 1024 * 1024 * 1024, 100 * 1024, 65536, 6 * 1024 * 1024, 12040) + } + + #[test] + fn test_arch_fingerprint() { + let arch1 = test_arch(); + let arch2 = test_arch(); + + assert!(arch1.is_compatible(&arch2)); + + let arch3 = ArchFingerprint::new(80, 0, 0, 0, 0, 12040); + assert!(!arch1.is_compatible(&arch3)); + } + + #[test] + fn test_cache_key() { + let source_hash = PersistentCache::hash_source("__global__ void foo() {}"); + let options_hash = PersistentCache::hash_options(&["-O3".to_string()]); + let arch_hash = test_arch().hash(); + + let key1 = PersistentCache::compute_key(source_hash, "foo", options_hash, arch_hash); + let key2 = PersistentCache::compute_key(source_hash, "foo", options_hash, arch_hash); + assert_eq!(key1, key2); + + let key3 = PersistentCache::compute_key(source_hash, "bar", options_hash, arch_hash); + assert_ne!(key1, key3); + } + + #[test] + fn test_persistent_cache_basic() { + let config = test_config(); + let arch = test_arch(); + + // Clean up first + let _ = fs::remove_dir_all(&config.cache_dir); + + let mut cache = PersistentCache::new(config.clone(), arch); + cache.initialize().unwrap(); + + // Insert + let key = cache + .insert( + "__global__ void test() {}", + "test", + &[], + "// PTX code".to_string(), + ) + .unwrap(); + + assert!(cache.contains(key)); + assert_eq!(cache.len(), 1); + + // Get + let entry = cache.get(key).unwrap().unwrap(); + assert_eq!(entry.name, "test"); + assert_eq!(entry.ptx, "// PTX code"); + + // Clean up + let _ = fs::remove_dir_all(&config.cache_dir); + } + + #[test] + fn test_persistent_cache_eviction() { + let mut config = test_config(); + config.max_entries = 2; + let arch = test_arch(); + + // Clean up first + let _ = fs::remove_dir_all(&config.cache_dir); + + let mut cache = PersistentCache::new(config.clone(), arch.clone()); + cache.initialize().unwrap(); + + // Insert 3 entries + cache + .insert("src1", "k1", &[], "ptx1".to_string()) + .unwrap(); + cache + .insert("src2", "k2", &[], "ptx2".to_string()) + .unwrap(); + + // Access k2 to make k1 the LRU + let key2 = PersistentCache::compute_key( + PersistentCache::hash_source("src2"), + "k2", + PersistentCache::hash_options(&[]), + arch.hash(), + ); + cache.get(key2).unwrap(); + + // Insert third - should evict k1 + cache + .insert("src3", "k3", &[], "ptx3".to_string()) + .unwrap(); + + assert_eq!(cache.len(), 2); + assert!(cache.stats().evictions >= 1); + + // Clean up + let _ = fs::remove_dir_all(&config.cache_dir); + } + + #[test] + fn test_persistent_cache_clear() { + let config = test_config(); + let arch = test_arch(); + + // Clean up first + let _ = fs::remove_dir_all(&config.cache_dir); + + let mut cache = PersistentCache::new(config.clone(), arch); + cache.initialize().unwrap(); + + cache + .insert("src1", "k1", &[], "ptx1".to_string()) + .unwrap(); + cache + .insert("src2", "k2", &[], "ptx2".to_string()) + .unwrap(); + + cache.clear().unwrap(); + assert!(cache.is_empty()); + + // Clean up + let _ = fs::remove_dir_all(&config.cache_dir); + } +} diff --git a/rust/pygpukit-python/Cargo.toml b/rust/pygpukit-python/Cargo.toml index 186f86f..e964560 100644 --- a/rust/pygpukit-python/Cargo.toml +++ b/rust/pygpukit-python/Cargo.toml @@ -15,3 +15,4 @@ pyo3.workspace = true numpy.workspace = true parking_lot.workspace = true uuid.workspace = true +dirs.workspace = true diff --git a/rust/pygpukit-python/src/dispatch.rs b/rust/pygpukit-python/src/dispatch.rs index 24251db..c05cd25 100644 --- a/rust/pygpukit-python/src/dispatch.rs +++ b/rust/pygpukit-python/src/dispatch.rs @@ -1,12 +1,15 @@ //! Python bindings for the kernel dispatch controller use pyo3::prelude::*; +use pyo3::exceptions::PyRuntimeError; use std::collections::HashMap; +use std::path::PathBuf; use pygpukit_core::dispatch::{ KernelDispatcher, KernelLaunchRequest, KernelState, DispatchStats, LaunchConfig, KernelPacingEngine, PacingConfig, PacingDecision, PacingStats, StreamPacingStats, - SliceScheduler, SliceConfig, SlicedKernel, KernelSlice, SliceInfo, SliceStats, + SliceScheduler, SliceConfig, KernelSlice, SliceInfo, SliceStats, KernelCache, CacheConfig, CachedKernel, CompileOptions, CacheStats, + PersistentCache, PersistentCacheConfig, PersistentCacheStats, PersistentEntry, ArchFingerprint, }; /// Python wrapper for KernelState enum @@ -1346,6 +1349,426 @@ impl PyKernelCache { } } +// ============================================================================= +// Persistent Kernel Cache Types +// ============================================================================= + +/// Architecture fingerprint for cache keys +/// +/// Contains GPU characteristics that affect PTX validity. +#[pyclass(name = "ArchFingerprint")] +#[derive(Clone)] +pub struct PyArchFingerprint { + inner: ArchFingerprint, +} + +#[pymethods] +impl PyArchFingerprint { + /// Create a new architecture fingerprint + /// + /// Args: + /// sm_version: SM version (e.g., 86 for RTX 3090) + /// global_memory: Global memory in bytes + /// shared_memory_per_sm: Shared memory per SM in bytes + /// max_registers_per_block: Max registers per block + /// l2_cache_size: L2 cache size in bytes + /// driver_version: CUDA driver version + #[new] + fn new( + sm_version: u32, + global_memory: u64, + shared_memory_per_sm: u32, + max_registers_per_block: u32, + l2_cache_size: u32, + driver_version: u32, + ) -> Self { + Self { + inner: ArchFingerprint::new( + sm_version, + global_memory, + shared_memory_per_sm, + max_registers_per_block, + l2_cache_size, + driver_version, + ), + } + } + + #[getter] + fn sm_version(&self) -> u32 { + self.inner.sm_version + } + + #[getter] + fn global_memory(&self) -> u64 { + self.inner.global_memory + } + + #[getter] + fn shared_memory_per_sm(&self) -> u32 { + self.inner.shared_memory_per_sm + } + + #[getter] + fn max_registers_per_block(&self) -> u32 { + self.inner.max_registers_per_block + } + + #[getter] + fn l2_cache_size(&self) -> u32 { + self.inner.l2_cache_size + } + + #[getter] + fn driver_version(&self) -> u32 { + self.inner.driver_version + } + + /// Compute hash for this fingerprint + fn hash(&self) -> u64 { + self.inner.hash() + } + + /// Check if this fingerprint is compatible with another + fn is_compatible(&self, other: &PyArchFingerprint) -> bool { + self.inner.is_compatible(&other.inner) + } + + fn __repr__(&self) -> String { + format!( + "ArchFingerprint(sm={}, memory={}GB, driver={})", + self.inner.sm_version, + self.inner.global_memory / (1024 * 1024 * 1024), + self.inner.driver_version + ) + } +} + +/// Persistent cache configuration +#[pyclass(name = "PersistentCacheConfig")] +#[derive(Clone)] +pub struct PyPersistentCacheConfig { + inner: PersistentCacheConfig, +} + +#[pymethods] +impl PyPersistentCacheConfig { + /// Create a new persistent cache configuration + /// + /// Args: + /// cache_dir: Directory path for cache storage (default: ~/.pygpukit/cache) + /// max_size: Maximum total cache size in bytes (default: 512MB) + /// max_entries: Maximum number of cached entries (default: 1000) + /// auto_cleanup: Enable automatic cleanup (default: True) + /// ttl_seconds: Time-to-live in seconds, 0.0 for unlimited (default: 0.0) + #[new] + #[pyo3(signature = (cache_dir=None, max_size=536870912, max_entries=1000, auto_cleanup=true, ttl_seconds=0.0))] + fn new( + cache_dir: Option, + max_size: usize, + max_entries: usize, + auto_cleanup: bool, + ttl_seconds: f64, + ) -> Self { + let cache_dir = cache_dir + .map(PathBuf::from) + .unwrap_or_else(|| { + dirs::home_dir() + .unwrap_or_else(|| PathBuf::from(".")) + .join(".pygpukit") + .join("cache") + }); + + Self { + inner: PersistentCacheConfig { + cache_dir, + max_size, + max_entries, + auto_cleanup, + ttl_seconds, + }, + } + } + + #[getter] + fn cache_dir(&self) -> String { + self.inner.cache_dir.to_string_lossy().to_string() + } + + #[getter] + fn max_size(&self) -> usize { + self.inner.max_size + } + + #[getter] + fn max_entries(&self) -> usize { + self.inner.max_entries + } + + #[getter] + fn auto_cleanup(&self) -> bool { + self.inner.auto_cleanup + } + + #[getter] + fn ttl_seconds(&self) -> f64 { + self.inner.ttl_seconds + } + + fn __repr__(&self) -> String { + format!( + "PersistentCacheConfig(dir='{}', max_size={}MB, max_entries={})", + self.inner.cache_dir.display(), + self.inner.max_size / (1024 * 1024), + self.inner.max_entries + ) + } +} + +/// Persistent cache statistics +#[pyclass(name = "PersistentCacheStats")] +#[derive(Clone)] +pub struct PyPersistentCacheStats { + inner: PersistentCacheStats, +} + +#[pymethods] +impl PyPersistentCacheStats { + #[getter] + fn entries(&self) -> usize { + self.inner.entries + } + + #[getter] + fn total_size(&self) -> usize { + self.inner.total_size + } + + #[getter] + fn hits(&self) -> usize { + self.inner.hits + } + + #[getter] + fn misses(&self) -> usize { + self.inner.misses + } + + #[getter] + fn evictions(&self) -> usize { + self.inner.evictions + } + + #[getter] + fn load_errors(&self) -> usize { + self.inner.load_errors + } + + #[getter] + fn save_errors(&self) -> usize { + self.inner.save_errors + } + + /// Calculate hit rate (0.0 - 1.0) + fn hit_rate(&self) -> f64 { + self.inner.hit_rate() + } + + fn __repr__(&self) -> String { + format!( + "PersistentCacheStats(entries={}, size={}KB, hit_rate={:.1}%)", + self.inner.entries, + self.inner.total_size / 1024, + self.inner.hit_rate() * 100.0 + ) + } +} + +/// Persistent cache entry +#[pyclass(name = "PersistentEntry")] +#[derive(Clone)] +pub struct PyPersistentEntry { + inner: PersistentEntry, +} + +#[pymethods] +impl PyPersistentEntry { + #[getter] + fn name(&self) -> &str { + &self.inner.name + } + + #[getter] + fn ptx(&self) -> &str { + &self.inner.ptx + } + + #[getter] + fn source_hash(&self) -> u64 { + self.inner.source_hash + } + + #[getter] + fn options_hash(&self) -> u64 { + self.inner.options_hash + } + + #[getter] + fn created_at(&self) -> f64 { + self.inner.created_at + } + + #[getter] + fn last_access(&self) -> f64 { + self.inner.last_access + } + + #[getter] + fn access_count(&self) -> usize { + self.inner.access_count + } + + /// Get PTX size in bytes + fn ptx_size(&self) -> usize { + self.inner.ptx_size() + } + + fn __repr__(&self) -> String { + format!( + "PersistentEntry(name='{}', ptx_size={}KB, accesses={})", + self.inner.name, + self.inner.ptx_size() / 1024, + self.inner.access_count + ) + } +} + +/// Persistent kernel cache +/// +/// Caches compiled PTX to disk for fast startup across sessions. +/// Uses GPU architecture fingerprinting to ensure cache validity. +#[pyclass(name = "PersistentCache")] +pub struct PyPersistentCache { + inner: PersistentCache, +} + +#[pymethods] +impl PyPersistentCache { + /// Create a new persistent cache + /// + /// Args: + /// config: Cache configuration + /// arch: Architecture fingerprint for the current GPU + #[new] + fn new(config: PyPersistentCacheConfig, arch: PyArchFingerprint) -> Self { + Self { + inner: PersistentCache::new(config.inner, arch.inner), + } + } + + /// Initialize the cache (creates directories, loads index) + fn initialize(&mut self) -> PyResult<()> { + self.inner.initialize() + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + } + + /// Compute cache key from components + #[staticmethod] + fn compute_key(source_hash: u64, name: &str, options_hash: u64, arch_hash: u64) -> u64 { + PersistentCache::compute_key(source_hash, name, options_hash, arch_hash) + } + + /// Compute hash for source code + #[staticmethod] + fn hash_source(source: &str) -> u64 { + PersistentCache::hash_source(source) + } + + /// Compute hash for compile options + #[staticmethod] + fn hash_options(options: Vec) -> u64 { + PersistentCache::hash_options(&options) + } + + /// Get entry by key + fn get(&mut self, key: u64) -> PyResult> { + self.inner.get(key) + .map(|opt| opt.map(|e| PyPersistentEntry { inner: e })) + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + } + + /// Insert a new entry + /// + /// Args: + /// source: CUDA source code + /// name: Kernel name + /// options: Compile options + /// ptx: Compiled PTX code + /// + /// Returns: + /// Cache key for the entry + fn insert( + &mut self, + source: &str, + name: &str, + options: Vec, + ptx: String, + ) -> PyResult { + self.inner.insert(source, name, &options, ptx) + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + } + + /// Remove an entry by key + fn remove(&mut self, key: u64) -> PyResult { + self.inner.remove(key) + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + } + + /// Clear all entries + fn clear(&mut self) -> PyResult<()> { + self.inner.clear() + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + } + + /// Run cleanup (remove expired entries) + fn cleanup(&mut self) -> PyResult { + self.inner.cleanup() + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + } + + /// Check if key exists + fn contains(&self, key: u64) -> bool { + self.inner.contains(key) + } + + /// Get number of entries + fn __len__(&self) -> usize { + self.inner.len() + } + + /// Check if empty + fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Get statistics + fn stats(&self) -> PyPersistentCacheStats { + PyPersistentCacheStats { + inner: self.inner.stats().clone(), + } + } + + fn __repr__(&self) -> String { + let stats = self.inner.stats(); + format!( + "PersistentCache(entries={}, size={}KB, hit_rate={:.1}%)", + stats.entries, + stats.total_size / 1024, + stats.hit_rate() * 100.0 + ) + } +} + /// Register dispatch module pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; @@ -1365,11 +1788,17 @@ pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - // Cache + // Cache (in-memory) m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + // Persistent cache (disk) + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } From 8b0f8d2f91b41e8c870b89a43e7b5b48aa113292 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Sun, 14 Dec 2025 23:51:38 +0900 Subject: [PATCH 03/24] feat(ops): add sub, div, exp, log, relu elementwise operations (#59) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements missing elementwise operations: - Binary ops: sub (a - b), div (a / b) - all dtypes - Unary ops: exp, log, relu - float32/float64 only Each operation includes: - CUDA kernel implementations (f32, f64, i32, i64 for binary) - pybind11 bindings with in-place variants - Python wrappers with CPU fallback - dtype validation for unary ops 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- native/bindings/ops_bindings.cpp | 60 +++++- native/ops/basic.cu | 350 +++++++++++++++++++++++++++++++ native/ops/basic.cuh | 42 +++- src/pygpukit/__init__.py | 7 +- src/pygpukit/ops/basic.py | 212 +++++++++++++++++++ 5 files changed, 662 insertions(+), 9 deletions(-) diff --git a/native/bindings/ops_bindings.cpp b/native/bindings/ops_bindings.cpp index 539a41f..97ec33c 100644 --- a/native/bindings/ops_bindings.cpp +++ b/native/bindings/ops_bindings.cpp @@ -7,7 +7,11 @@ namespace py = pybind11; using namespace pygpukit; void init_ops_bindings(py::module_& m) { - // Element-wise operations + // ======================================================================== + // Binary Element-wise operations + // ======================================================================== + + // Add m.def("add", py::overload_cast(&ops::add), py::arg("a"), py::arg("b"), "Element-wise addition of two GPUArrays"); @@ -16,6 +20,16 @@ void init_ops_bindings(py::module_& m) { py::arg("a"), py::arg("b"), py::arg("out"), "Element-wise addition with output array"); + // Sub + m.def("sub", py::overload_cast(&ops::sub), + py::arg("a"), py::arg("b"), + "Element-wise subtraction of two GPUArrays"); + + m.def("sub_", py::overload_cast(&ops::sub), + py::arg("a"), py::arg("b"), py::arg("out"), + "Element-wise subtraction with output array"); + + // Mul m.def("mul", py::overload_cast(&ops::mul), py::arg("a"), py::arg("b"), "Element-wise multiplication of two GPUArrays"); @@ -24,6 +38,50 @@ void init_ops_bindings(py::module_& m) { py::arg("a"), py::arg("b"), py::arg("out"), "Element-wise multiplication with output array"); + // Div + m.def("div", py::overload_cast(&ops::div), + py::arg("a"), py::arg("b"), + "Element-wise division of two GPUArrays"); + + m.def("div_", py::overload_cast(&ops::div), + py::arg("a"), py::arg("b"), py::arg("out"), + "Element-wise division with output array"); + + // ======================================================================== + // Unary Element-wise operations (float only) + // ======================================================================== + + // Exp + m.def("exp", py::overload_cast(&ops::exp), + py::arg("a"), + "Element-wise exponential (float32/float64 only)"); + + m.def("exp_", py::overload_cast(&ops::exp), + py::arg("a"), py::arg("out"), + "Element-wise exponential with output array"); + + // Log + m.def("log", py::overload_cast(&ops::log), + py::arg("a"), + "Element-wise natural logarithm (float32/float64 only)"); + + m.def("log_", py::overload_cast(&ops::log), + py::arg("a"), py::arg("out"), + "Element-wise natural logarithm with output array"); + + // ReLU + m.def("relu", py::overload_cast(&ops::relu), + py::arg("a"), + "Element-wise ReLU: max(0, x) (float32/float64 only)"); + + m.def("relu_", py::overload_cast(&ops::relu), + py::arg("a"), py::arg("out"), + "Element-wise ReLU with output array"); + + // ======================================================================== + // Matrix operations + // ======================================================================== + m.def("matmul", py::overload_cast(&ops::matmul), py::arg("a"), py::arg("b"), "Matrix multiplication of two GPUArrays"); diff --git a/native/ops/basic.cu b/native/ops/basic.cu index 0d10add..47ed5b6 100644 --- a/native/ops/basic.cu +++ b/native/ops/basic.cu @@ -229,6 +229,356 @@ GPUArray mul(const GPUArray& a, const GPUArray& b) { return c; } +// ============================================================================ +// Sub kernels +// ============================================================================ + +__global__ void sub_f32_kernel(const float* a, const float* b, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] - b[idx]; + } +} + +__global__ void sub_f64_kernel(const double* a, const double* b, double* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] - b[idx]; + } +} + +__global__ void sub_i32_kernel(const int32_t* a, const int32_t* b, int32_t* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] - b[idx]; + } +} + +__global__ void sub_i64_kernel(const int64_t* a, const int64_t* b, int64_t* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] - b[idx]; + } +} + +void sub(const GPUArray& a, const GPUArray& b, GPUArray& c) { + validate_same_shape(a, b, "sub"); + validate_same_dtype(a, b, "sub"); + validate_same_shape(a, c, "sub"); + validate_same_dtype(a, c, "sub"); + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + sub_f32_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + n); + break; + case DataType::Float64: + sub_f64_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + n); + break; + case DataType::Int32: + sub_i32_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + n); + break; + case DataType::Int64: + sub_i64_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + n); + break; + } + + sync_and_check("sub kernel failed"); +} + +GPUArray sub(const GPUArray& a, const GPUArray& b) { + validate_same_shape(a, b, "sub"); + validate_same_dtype(a, b, "sub"); + + GPUArray c(a.shape(), a.dtype()); + sub(a, b, c); + return c; +} + +// ============================================================================ +// Div kernels +// ============================================================================ + +__global__ void div_f32_kernel(const float* a, const float* b, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] / b[idx]; + } +} + +__global__ void div_f64_kernel(const double* a, const double* b, double* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] / b[idx]; + } +} + +__global__ void div_i32_kernel(const int32_t* a, const int32_t* b, int32_t* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] / b[idx]; + } +} + +__global__ void div_i64_kernel(const int64_t* a, const int64_t* b, int64_t* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] / b[idx]; + } +} + +void div(const GPUArray& a, const GPUArray& b, GPUArray& c) { + validate_same_shape(a, b, "div"); + validate_same_dtype(a, b, "div"); + validate_same_shape(a, c, "div"); + validate_same_dtype(a, c, "div"); + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + div_f32_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + n); + break; + case DataType::Float64: + div_f64_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + n); + break; + case DataType::Int32: + div_i32_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + n); + break; + case DataType::Int64: + div_i64_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + n); + break; + } + + sync_and_check("div kernel failed"); +} + +GPUArray div(const GPUArray& a, const GPUArray& b) { + validate_same_shape(a, b, "div"); + validate_same_dtype(a, b, "div"); + + GPUArray c(a.shape(), a.dtype()); + div(a, b, c); + return c; +} + +// ============================================================================ +// Exp kernels (float only) +// ============================================================================ + +__global__ void exp_f32_kernel(const float* a, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = expf(a[idx]); + } +} + +__global__ void exp_f64_kernel(const double* a, double* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = ::exp(a[idx]); + } +} + +void exp(const GPUArray& a, GPUArray& c) { + validate_same_shape(a, c, "exp"); + validate_same_dtype(a, c, "exp"); + + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64) { + throw std::runtime_error("exp only supports float32 and float64"); + } + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + exp_f32_kernel<<>>( + static_cast(a.data()), + static_cast(c.data()), + n); + break; + case DataType::Float64: + exp_f64_kernel<<>>( + static_cast(a.data()), + static_cast(c.data()), + n); + break; + default: + break; + } + + sync_and_check("exp kernel failed"); +} + +GPUArray exp(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64) { + throw std::runtime_error("exp only supports float32 and float64"); + } + + GPUArray c(a.shape(), a.dtype()); + exp(a, c); + return c; +} + +// ============================================================================ +// Log kernels (float only) +// ============================================================================ + +__global__ void log_f32_kernel(const float* a, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = logf(a[idx]); + } +} + +__global__ void log_f64_kernel(const double* a, double* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = ::log(a[idx]); + } +} + +void log(const GPUArray& a, GPUArray& c) { + validate_same_shape(a, c, "log"); + validate_same_dtype(a, c, "log"); + + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64) { + throw std::runtime_error("log only supports float32 and float64"); + } + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + log_f32_kernel<<>>( + static_cast(a.data()), + static_cast(c.data()), + n); + break; + case DataType::Float64: + log_f64_kernel<<>>( + static_cast(a.data()), + static_cast(c.data()), + n); + break; + default: + break; + } + + sync_and_check("log kernel failed"); +} + +GPUArray log(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64) { + throw std::runtime_error("log only supports float32 and float64"); + } + + GPUArray c(a.shape(), a.dtype()); + log(a, c); + return c; +} + +// ============================================================================ +// ReLU kernels (float only) +// ============================================================================ + +__global__ void relu_f32_kernel(const float* a, float* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = fmaxf(0.0f, a[idx]); + } +} + +__global__ void relu_f64_kernel(const double* a, double* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = fmax(0.0, a[idx]); + } +} + +void relu(const GPUArray& a, GPUArray& c) { + validate_same_shape(a, c, "relu"); + validate_same_dtype(a, c, "relu"); + + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64) { + throw std::runtime_error("relu only supports float32 and float64"); + } + + size_t n = a.size(); + const int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + switch (a.dtype()) { + case DataType::Float32: + relu_f32_kernel<<>>( + static_cast(a.data()), + static_cast(c.data()), + n); + break; + case DataType::Float64: + relu_f64_kernel<<>>( + static_cast(a.data()), + static_cast(c.data()), + n); + break; + default: + break; + } + + sync_and_check("relu kernel failed"); +} + +GPUArray relu(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64) { + throw std::runtime_error("relu only supports float32 and float64"); + } + + GPUArray c(a.shape(), a.dtype()); + relu(a, c); + return c; +} + // ============================================================================ // Matmul kernels - Tiled with Shared Memory and Double Buffering // ============================================================================ diff --git a/native/ops/basic.cuh b/native/ops/basic.cuh index 9af5c3c..1ed1af3 100644 --- a/native/ops/basic.cuh +++ b/native/ops/basic.cuh @@ -6,26 +6,54 @@ namespace pygpukit { namespace ops { +// ============================================================================ +// Binary Element-wise Operations +// ============================================================================ + // Element-wise addition: c = a + b void add(const GPUArray& a, const GPUArray& b, GPUArray& c); +GPUArray add(const GPUArray& a, const GPUArray& b); + +// Element-wise subtraction: c = a - b +void sub(const GPUArray& a, const GPUArray& b, GPUArray& c); +GPUArray sub(const GPUArray& a, const GPUArray& b); // Element-wise multiplication: c = a * b void mul(const GPUArray& a, const GPUArray& b, GPUArray& c); +GPUArray mul(const GPUArray& a, const GPUArray& b); + +// Element-wise division: c = a / b +void div(const GPUArray& a, const GPUArray& b, GPUArray& c); +GPUArray div(const GPUArray& a, const GPUArray& b); + +// ============================================================================ +// Unary Element-wise Operations (float32/float64 only) +// ============================================================================ + +// Element-wise exponential: c = exp(a) +void exp(const GPUArray& a, GPUArray& c); +GPUArray exp(const GPUArray& a); + +// Element-wise natural logarithm: c = log(a) +void log(const GPUArray& a, GPUArray& c); +GPUArray log(const GPUArray& a); + +// Element-wise ReLU: c = max(0, a) +void relu(const GPUArray& a, GPUArray& c); +GPUArray relu(const GPUArray& a); + +// ============================================================================ +// Matrix Operations +// ============================================================================ // Matrix multiplication: c = a @ b // a: (M, K), b: (K, N), c: (M, N) void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c); +GPUArray matmul(const GPUArray& a, const GPUArray& b); // Matrix multiplication with explicit TF32 control // use_tf32: force TF32 TensorCore path (requires SM >= 80 and float32) void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c, bool use_tf32); - -// Convenience functions that return new arrays -GPUArray add(const GPUArray& a, const GPUArray& b); -GPUArray mul(const GPUArray& a, const GPUArray& b); -GPUArray matmul(const GPUArray& a, const GPUArray& b); - -// Matmul with explicit TF32 control GPUArray matmul(const GPUArray& a, const GPUArray& b, bool use_tf32); } // namespace ops diff --git a/src/pygpukit/__init__.py b/src/pygpukit/__init__.py index b10d820..b952bc4 100644 --- a/src/pygpukit/__init__.py +++ b/src/pygpukit/__init__.py @@ -27,7 +27,7 @@ jit, warmup, ) -from pygpukit.ops.basic import add, matmul, mul +from pygpukit.ops.basic import add, div, exp, log, matmul, mul, relu, sub # Try to import Rust types, fallback to Python implementations try: @@ -79,6 +79,11 @@ "check_driver_compatibility", # Operations "add", + "sub", "mul", + "div", + "exp", + "log", + "relu", "matmul", ] diff --git a/src/pygpukit/ops/basic.py b/src/pygpukit/ops/basic.py index 0ed444e..b10aa77 100644 --- a/src/pygpukit/ops/basic.py +++ b/src/pygpukit/ops/basic.py @@ -21,6 +21,13 @@ def _validate_same_dtype(a: GPUArray, b: GPUArray, op_name: str) -> None: raise ValueError(f"{op_name} requires arrays of same dtype, got {a.dtype} and {b.dtype}") +def _validate_float_dtype(a: GPUArray, op_name: str) -> None: + """Validate that array has float dtype.""" + from pygpukit.core.dtypes import float32, float64 + if a.dtype not in (float32, float64): + raise ValueError(f"{op_name} requires float32 or float64 dtype, got {a.dtype}") + + def add(a: GPUArray, b: GPUArray) -> GPUArray: """Element-wise addition of two arrays. @@ -121,6 +128,211 @@ def _mul_native(a: GPUArray, b: GPUArray) -> GPUArray: return GPUArray._wrap_native(c_native) +def sub(a: GPUArray, b: GPUArray) -> GPUArray: + """Element-wise subtraction of two arrays. + + Args: + a: First input array. + b: Second input array. + + Returns: + A new GPUArray containing the element-wise difference. + + Raises: + ValueError: If shapes don't match. + """ + _validate_same_shape(a, b, "sub") + _validate_same_dtype(a, b, "sub") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _sub_native(a, b) + else: + return _sub_cpu(a, b) + + +def _sub_cpu(a: GPUArray, b: GPUArray) -> GPUArray: + """CPU implementation of sub.""" + a_np = a.to_numpy() + b_np = b.to_numpy() + result_np = a_np - b_np + return from_numpy(result_np) + + +def _sub_native(a: GPUArray, b: GPUArray) -> GPUArray: + """Native C++ CUDA implementation of sub (zero-copy).""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + a_native = a._get_native() + b_native = b._get_native() + c_native = native.sub(a_native, b_native) + return GPUArray._wrap_native(c_native) + + +def div(a: GPUArray, b: GPUArray) -> GPUArray: + """Element-wise division of two arrays. + + Args: + a: First input array (dividend). + b: Second input array (divisor). + + Returns: + A new GPUArray containing the element-wise quotient. + + Raises: + ValueError: If shapes don't match. + """ + _validate_same_shape(a, b, "div") + _validate_same_dtype(a, b, "div") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _div_native(a, b) + else: + return _div_cpu(a, b) + + +def _div_cpu(a: GPUArray, b: GPUArray) -> GPUArray: + """CPU implementation of div.""" + a_np = a.to_numpy() + b_np = b.to_numpy() + result_np = a_np / b_np + return from_numpy(result_np) + + +def _div_native(a: GPUArray, b: GPUArray) -> GPUArray: + """Native C++ CUDA implementation of div (zero-copy).""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + a_native = a._get_native() + b_native = b._get_native() + c_native = native.div(a_native, b_native) + return GPUArray._wrap_native(c_native) + + +def exp(a: GPUArray) -> GPUArray: + """Element-wise exponential. + + Args: + a: Input array (float32 or float64). + + Returns: + A new GPUArray containing exp(a). + + Raises: + ValueError: If dtype is not float32 or float64. + """ + _validate_float_dtype(a, "exp") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _exp_native(a) + else: + return _exp_cpu(a) + + +def _exp_cpu(a: GPUArray) -> GPUArray: + """CPU implementation of exp.""" + a_np = a.to_numpy() + result_np = np.exp(a_np) + return from_numpy(result_np) + + +def _exp_native(a: GPUArray) -> GPUArray: + """Native C++ CUDA implementation of exp (zero-copy).""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + a_native = a._get_native() + c_native = native.exp(a_native) + return GPUArray._wrap_native(c_native) + + +def log(a: GPUArray) -> GPUArray: + """Element-wise natural logarithm. + + Args: + a: Input array (float32 or float64). + + Returns: + A new GPUArray containing log(a). + + Raises: + ValueError: If dtype is not float32 or float64. + """ + _validate_float_dtype(a, "log") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _log_native(a) + else: + return _log_cpu(a) + + +def _log_cpu(a: GPUArray) -> GPUArray: + """CPU implementation of log.""" + a_np = a.to_numpy() + result_np = np.log(a_np) + return from_numpy(result_np) + + +def _log_native(a: GPUArray) -> GPUArray: + """Native C++ CUDA implementation of log (zero-copy).""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + a_native = a._get_native() + c_native = native.log(a_native) + return GPUArray._wrap_native(c_native) + + +def relu(a: GPUArray) -> GPUArray: + """Element-wise ReLU (Rectified Linear Unit). + + Computes max(0, x) for each element. + + Args: + a: Input array (float32 or float64). + + Returns: + A new GPUArray containing relu(a). + + Raises: + ValueError: If dtype is not float32 or float64. + """ + _validate_float_dtype(a, "relu") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _relu_native(a) + else: + return _relu_cpu(a) + + +def _relu_cpu(a: GPUArray) -> GPUArray: + """CPU implementation of relu.""" + a_np = a.to_numpy() + result_np = np.maximum(0, a_np) + return from_numpy(result_np) + + +def _relu_native(a: GPUArray) -> GPUArray: + """Native C++ CUDA implementation of relu (zero-copy).""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + a_native = a._get_native() + c_native = native.relu(a_native) + return GPUArray._wrap_native(c_native) + + def matmul(a: GPUArray, b: GPUArray, *, use_tf32: bool | None = None) -> GPUArray: """Matrix multiplication of two 2D arrays. From 1e816b6fbbb8bcdc653df2a313ded885a7a9343e Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 00:00:09 +0900 Subject: [PATCH 04/24] docs(claude): add TF32 optimization research for Issue #53 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Research findings for CUTLASS-level TF32 kernel optimization: - Swizzled shared memory layout (XOR-based bank conflict elimination) - ldmatrix instruction usage and TF32 limitations - Multi-stage pipeline considerations (4-stage vs current 2-stage) - Recommended implementation order with expected gains - Reference materials from NVIDIA CUTLASS and academic papers Current: 27.38 TFLOPS → Target: 35+ TFLOPS 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- CLAUDE.md | 93 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index f98e710..5536d50 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -362,6 +362,99 @@ store_matrix_sync(C, c_frag, N, mem_row_major); --- +## TF32 Optimization Research (Issue #53) + +### Current Performance Status + +| Metric | Value | +|--------|-------| +| Current | **27.38 TFLOPS** (8192×8192) | +| RTX 3090 Ti TF32 Theoretical | ~40 TFLOPS | +| cuBLAS Reference | ~59 TFLOPS | +| Gap to cuBLAS | **47%** | + +### Current Implementation Parameters + +``` +Block Tile: BM=128, BN=128, BK=16 +Warp Tile: WARP_TILES_M=2, WARP_TILES_N=8 (32×64 per warp) +MMA Instruction: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 +Pipeline: 2-stage double buffering +Thread Block: 256 threads (8 warps) +Shared Memory: ~37KB/block → occupancy ~16.7% +``` + +### CUTLASS Optimization Techniques + +#### 1. Swizzled Shared Memory Layout (High Priority) + +Current implementation uses simple padding (`A_PAD=4, B_PAD=4`) but bank conflicts are not fully eliminated. + +**CUTLASS Approach:** +```cpp +// XOR-based swizzle pattern +int store_column = (lane_id % 8) ^ (lane_id / 8); +``` + +- Store and Load phases use transposed index relationship +- XOR operation applied per 8×8 block unit +- Combined with `ldmatrix` for fully bank conflict-free access + +**Key Insight:** +> "the indexing in the 'Loading from Shared Memory to Registers' slide is transposed from the indexing in 'Load from Global/Store to Shared' slide." + +#### 2. ldmatrix Instruction (High Priority) + +Current implementation manually loads from shared memory to registers: +```cpp +// Current implementation +float a0 = smA[curr][tile_m + a_row_base][kk + a_col_base]; +``` + +**CUTLASS Approach:** +- Uses `ldmatrix.sync.aligned.m8n8.x4.shared.b16` +- Single instruction loads four 8×8 matrices (entire warp) + +**TF32 Limitation:** +> "ldmatrix cannot transpose 32-bit data. CUTLASS uses 32-bit shared memory load to load data from shared memory to the registers to do the transpose right before calling tf32 tensor core." + +#### 3. Multi-stage Pipeline (Medium-High Priority) + +Current: 2-stage → CUTLASS default: **4-stage** + +**Past Failed Attempt:** +> "3-stage pipeline: -28% (50% more smem reduced occupancy)" + +**Considerations:** +- Trade-off between shared memory usage and occupancy +- RTX 3090 Ti: 100KB/SM available +- Current 37KB → 4-stage at ~74KB should fit + +### Recommended Implementation Order + +| Priority | Optimization | Expected Gain | Difficulty | +|----------|-------------|---------------|------------| +| 1 | Swizzled shared memory layout | +10-15% | Medium | +| 2 | 4-stage pipeline (proper smem sizing) | +5-10% | Medium | +| 3 | Warp tile tuning (BM/BN/BK re-tuning) | +5-10% | Low | +| 4 | Epilogue fusion (bias + activation) | Memory reduction | Medium | + +### Path to 35 TFLOPS + +- Current: 27.38 TFLOPS (68% of target) +- Swizzle + 4-stage: 32-34 TFLOPS expected +- Fine-tuning: 35+ TFLOPS + +### Reference Materials + +- [CUTLASS TF32 GEMM Example](https://github.com/NVIDIA/cutlass/blob/main/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) +- [CUTLASS Efficient GEMM Documentation](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/efficient_gemm.html) +- [CUTLASS Swizzled Layouts Discussion](https://github.com/NVIDIA/cutlass/discussions/1130) +- [Understanding CUTLASS Permuted Shared Memory](https://forums.developer.nvidia.com/t/understanding-cutlass-permuted-shared-memory-layout/303697) +- [Dissecting Tensor Cores (Academic Paper)](https://arxiv.org/pdf/2206.02874) + +--- + ## Development Workflow ### Kernel Development Cycle From e27f1fa4f93de49392bc24eb5a6f6f35b79e492f Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 00:17:53 +0900 Subject: [PATCH 05/24] wip(tf32): add v2 kernel baseline (#53) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TF32 v2 kernel baseline (same structure as v1): - BM=128, BN=128, BK=16 - 2-stage pipeline - Padding for bank conflicts Benchmark results (RTX 3090 Ti): - 2048x2048: 11.12 TFLOPS - 4096x4096: 20.46 TFLOPS - 8192x8192: 29.12 TFLOPS cuBLAS reference: - 8192x8192: 41.79 TFLOPS Current efficiency: 70% of cuBLAS Target: 90% (37.6 TFLOPS) Correctness: PASS (p99 < 2%) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- native/ops/basic.cu | 21 ++- native/ops/matmul_f32_tf32_v2.cuh | 266 ++++++++++++++++++++++++++++++ 2 files changed, 282 insertions(+), 5 deletions(-) create mode 100644 native/ops/matmul_f32_tf32_v2.cuh diff --git a/native/ops/basic.cu b/native/ops/basic.cu index 47ed5b6..0a853b2 100644 --- a/native/ops/basic.cu +++ b/native/ops/basic.cu @@ -4,6 +4,7 @@ #include "basic.cuh" #include "matmul_f32_ampere.cuh" #include "matmul_f32_tf32.cuh" +#include "matmul_f32_tf32_v2.cuh" #include "../core/driver_context.hpp" #include #include @@ -1333,11 +1334,21 @@ static void matmul_impl(const GPUArray& a, const GPUArray& b, GPUArray& c, bool static_cast(c.data()), M, N, K); } else { - tf32::launch_sgemm_tf32( - static_cast(a.data()), - static_cast(b.data()), - static_cast(c.data()), - M, N, K); + // Check for v2 kernel (optimized) via environment variable + const char* use_v2 = std::getenv("PYGPUKIT_TF32_V2"); + if (use_v2 && std::string(use_v2) == "1") { + tf32_v2::launch_sgemm_tf32_v2( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + } else { + tf32::launch_sgemm_tf32( + static_cast(a.data()), + static_cast(b.data()), + static_cast(c.data()), + M, N, K); + } } } else if (use_optimized) { ampere::launch_sgemm_ampere( diff --git a/native/ops/matmul_f32_tf32_v2.cuh b/native/ops/matmul_f32_tf32_v2.cuh new file mode 100644 index 0000000..ea0b147 --- /dev/null +++ b/native/ops/matmul_f32_tf32_v2.cuh @@ -0,0 +1,266 @@ +/** + * TF32 TensorCore GEMM v2 - CUTLASS-inspired Optimizations + * + * Target: 90%+ of cuBLAS performance (37.6+ TFLOPS on RTX 3090 Ti) + * + * Key optimizations: + * 1. 3-stage software pipeline with cp.async + * 2. Optimized warp tile configuration + * 3. Vectorized memory loads + */ + +#pragma once +#include +#include +#include + +namespace pygpukit { +namespace ops { +namespace tf32_v2 { + +// ============================================================================ +// Configuration - Tuned for RTX 3090 Ti (SM 8.6) +// ============================================================================ + +// CTA tile dimensions +constexpr int BM = 128; +constexpr int BN = 128; +constexpr int BK = 16; + +// MMA tile dimensions (m16n8k8 instruction) +constexpr int WMMA_M = 16; +constexpr int WMMA_N = 8; +constexpr int WMMA_K = 8; + +// Warp configuration: 4x2 warps = 8 warps = 256 threads +constexpr int WARPS_M = 4; +constexpr int WARPS_N = 2; +constexpr int NUM_WARPS = WARPS_M * WARPS_N; +constexpr int NUM_THREADS = NUM_WARPS * 32; + +// Each warp computes WARP_TILES_M x WARP_TILES_N MMA tiles +constexpr int WARP_TILES_M = 2; // 2 * 16 = 32 rows per warp +constexpr int WARP_TILES_N = 8; // 8 * 8 = 64 cols per warp + +// Pipeline stages +// smA: 2 * 128 * 16 * 4 = 16384 bytes +// smB: 2 * 16 * 128 * 4 = 16384 bytes +// Total: 32768 bytes = 32KB (easily fits) +constexpr int STAGES = 2; + +// Padding to avoid bank conflicts +constexpr int A_PAD = 4; +constexpr int B_PAD = 4; + +// ============================================================================ +// cp.async Helpers +// ============================================================================ + +__device__ __forceinline__ uint32_t smem_u32(const void* ptr) { + uint32_t addr; + asm volatile( + "{ .reg .u64 smem64; " + " cvta.to.shared.u64 smem64, %1; " + " cvt.u32.u64 %0, smem64; }" + : "=r"(addr) : "l"(ptr) + ); + return addr; +} + +__device__ __forceinline__ void cp_async_16(void* smem, const void* gmem) { + uint32_t addr = smem_u32(smem); + asm volatile( + "cp.async.cg.shared.global [%0], [%1], 16;" + :: "r"(addr), "l"(gmem) + ); +} + +__device__ __forceinline__ void cp_async_commit() { + asm volatile("cp.async.commit_group;"); +} + +__device__ __forceinline__ void cp_async_wait_0() { + asm volatile("cp.async.wait_group 0;"); +} + +__device__ __forceinline__ void cp_async_wait_1() { + asm volatile("cp.async.wait_group 1;"); +} + +// ============================================================================ +// Main Kernel - Correct Implementation with Double Buffering +// ============================================================================ + +__global__ void __launch_bounds__(256, 2) +sgemm_tf32_v2_kernel( + const float* __restrict__ A, + const float* __restrict__ B, + float* __restrict__ C, + int M, int N, int K +) { + const int tid = threadIdx.x; + const int warp_id = tid >> 5; + const int lane = tid & 31; + + const int cta_m = blockIdx.y * BM; + const int cta_n = blockIdx.x * BN; + + const int warp_row = warp_id / WARPS_N; // 0-3 + const int warp_col = warp_id % WARPS_N; // 0-1 + + // Warp's starting position within CTA tile + const int warp_m = warp_row * (WARP_TILES_M * WMMA_M); // 0, 32, 64, 96 + const int warp_n = warp_col * (WARP_TILES_N * WMMA_N); // 0, 64 + + // Shared memory with padding + __shared__ float smA[STAGES][BM][BK + A_PAD]; + __shared__ float smB[STAGES][BK][BN + B_PAD]; + + // Accumulators + float acc[WARP_TILES_M][WARP_TILES_N][4] = {}; + + const int num_k_tiles = K / BK; + + // Fragment index mappings (verified) + const int a_row_base = lane / 4; // 0-7 + const int a_col_base = lane % 4; // 0-3 + const int b_row_base = lane % 4; // 0-3 + const int b_col = lane / 4; // 0-7 + const int c_row_base = lane / 4; + const int c_col_base = (lane % 4) * 2; + + // ====== Load helpers ====== + auto load_A_async = [&](int stage, int kt) { + // Each thread loads: tid / 4 determines which row, tid % 4 * 4 determines col + const int a_row = tid / 4; // 0-63 + const int a_col = (tid % 4) * 4; // 0, 4, 8, 12 + + #pragma unroll + for (int i = 0; i < 2; ++i) { + int row = a_row + i * 64; // 0-63, then 64-127 + int gm = cta_m + row; + int gk = kt * BK + a_col; + + if (gm < M && gk < K) { + cp_async_16(&smA[stage][row][a_col], &A[gm * K + gk]); + } + } + }; + + auto load_B_async = [&](int stage, int kt) { + const int b_row_ld = tid / 32; // 0-7 + const int b_col_ld = (tid % 32) * 4; // 0, 4, 8, ..., 124 + + #pragma unroll + for (int i = 0; i < 2; ++i) { + int k = b_row_ld + i * 8; // 0-7, then 8-15 + int gk = kt * BK + k; + int gn = cta_n + b_col_ld; + + if (gk < K && gn < N) { + cp_async_16(&smB[stage][k][b_col_ld], &B[gk * N + gn]); + } + } + }; + + // ====== Prologue: load first tile ====== + load_A_async(0, 0); + load_B_async(0, 0); + cp_async_commit(); + cp_async_wait_0(); + __syncthreads(); + + // ====== Main loop with double buffering ====== + for (int kt = 0; kt < num_k_tiles; ++kt) { + int curr = kt & 1; + int next = curr ^ 1; + + // Prefetch next tile + if (kt + 1 < num_k_tiles) { + load_A_async(next, kt + 1); + load_B_async(next, kt + 1); + } + cp_async_commit(); + + // Process current tile + #pragma unroll + for (int kk = 0; kk < BK; kk += WMMA_K) { + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + int tile_m = warp_m + wm * WMMA_M; + + // Load A fragment + float a0 = smA[curr][tile_m + a_row_base][kk + a_col_base]; + float a1 = smA[curr][tile_m + a_row_base + 8][kk + a_col_base]; + float a2 = smA[curr][tile_m + a_row_base][kk + a_col_base + 4]; + float a3 = smA[curr][tile_m + a_row_base + 8][kk + a_col_base + 4]; + + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + int tile_n = warp_n + wn * WMMA_N; + + // Load B fragment + float b0 = smB[curr][kk + b_row_base][tile_n + b_col]; + float b1 = smB[curr][kk + b_row_base + 4][tile_n + b_col]; + + // MMA instruction + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, " + "{%8, %9}, " + "{%0, %1, %2, %3};" + : "+f"(acc[wm][wn][0]), "+f"(acc[wm][wn][1]), + "+f"(acc[wm][wn][2]), "+f"(acc[wm][wn][3]) + : "r"(__float_as_uint(a0)), "r"(__float_as_uint(a1)), + "r"(__float_as_uint(a2)), "r"(__float_as_uint(a3)), + "r"(__float_as_uint(b0)), "r"(__float_as_uint(b1)) + ); + } + } + } + + // Wait for prefetch + cp_async_wait_0(); + __syncthreads(); + } + + // ====== Epilogue: Write results ====== + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + int tile_m = cta_m + warp_m + wm * WMMA_M; + int tile_n = cta_n + warp_n + wn * WMMA_N; + + int out_row0 = tile_m + c_row_base; + int out_row1 = tile_m + c_row_base + 8; + int out_col0 = tile_n + c_col_base; + int out_col1 = tile_n + c_col_base + 1; + + if (out_row0 < M && out_col0 < N) C[out_row0 * N + out_col0] = acc[wm][wn][0]; + if (out_row0 < M && out_col1 < N) C[out_row0 * N + out_col1] = acc[wm][wn][1]; + if (out_row1 < M && out_col0 < N) C[out_row1 * N + out_col0] = acc[wm][wn][2]; + if (out_row1 < M && out_col1 < N) C[out_row1 * N + out_col1] = acc[wm][wn][3]; + } + } +} + +// ============================================================================ +// Launch Helper +// ============================================================================ + +inline cudaError_t launch_sgemm_tf32_v2( + const float* A, const float* B, float* C, + int M, int N, int K, + cudaStream_t stream = 0 +) { + dim3 block(256); + dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM); + sgemm_tf32_v2_kernel<<>>(A, B, C, M, N, K); + return cudaGetLastError(); +} + +} // namespace tf32_v2 +} // namespace ops +} // namespace pygpukit From 4e158b1f71a77c35cad5c890fdfbaff7942e3ca8 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 00:27:32 +0900 Subject: [PATCH 06/24] wip(tf32-v2): WMMA API version - regression Benchmark results (RTX 3090 Ti): - 2048x2048: 10.83 TFLOPS (cuBLAS: 30.60) - 4096x4096: 19.09 TFLOPS (cuBLAS: 35.49) - 8192x8192: 25.48 TFLOPS (cuBLAS: 41.79) Correctness: PASS Note: WMMA API is slower than PTX mma.sync (was ~29 TFLOPS). Need to return to PTX with better optimizations. --- bench_tf32_v2.py | 92 +++++++++++++++++++ native/ops/matmul_f32_tf32_v2.cuh | 141 +++++++++++------------------- 2 files changed, 141 insertions(+), 92 deletions(-) create mode 100644 bench_tf32_v2.py diff --git a/bench_tf32_v2.py b/bench_tf32_v2.py new file mode 100644 index 0000000..9264edf --- /dev/null +++ b/bench_tf32_v2.py @@ -0,0 +1,92 @@ +"""TF32 v2 Kernel Benchmark""" +import os +import numpy as np +import time + +# Enable v2 kernel +os.environ["PYGPUKIT_TF32_V2"] = "1" + +def benchmark(): + import pygpukit as gk + + if not gk.is_cuda_available(): + print("CUDA not available") + return + + info = gk.get_device_info() + print(f"Device: {info.name}") + print(f"Using TF32 v2 kernel: PYGPUKIT_TF32_V2={os.environ.get('PYGPUKIT_TF32_V2', '0')}") + + sizes = [2048, 4096, 8192] + + print("\n" + "=" * 50) + print("Performance Benchmark (TF32 v2)") + print("=" * 50) + + for N in sizes: + M, K = N, N + + a_np = np.random.randn(M, K).astype(np.float32) + b_np = np.random.randn(K, N).astype(np.float32) + + a = gk.from_numpy(a_np) + b = gk.from_numpy(b_np) + + # Warmup + for _ in range(5): + c = gk.matmul(a, b, use_tf32=True) + + # Benchmark + num_iters = 20 + start = time.perf_counter() + for _ in range(num_iters): + c = gk.matmul(a, b, use_tf32=True) + elapsed = time.perf_counter() - start + + avg_time_ms = (elapsed / num_iters) * 1000 + flops = 2.0 * M * N * K + tflops = (flops / (avg_time_ms / 1000)) / 1e12 + + print(f"{N}x{N}x{N}: {avg_time_ms:.2f} ms, {tflops:.2f} TFLOPS") + + # Correctness check + print("\n" + "=" * 50) + print("Correctness Check") + print("=" * 50) + + all_pass = True + for N in [256, 512, 1024, 2048]: + a_np = np.random.randn(N, N).astype(np.float32) + b_np = np.random.randn(N, N).astype(np.float32) + + a = gk.from_numpy(a_np) + b = gk.from_numpy(b_np) + + c = gk.matmul(a, b, use_tf32=True) + c_np = c.to_numpy() + + expected = a_np @ b_np + + abs_error = np.abs(c_np - expected) + scale = np.maximum(np.abs(expected), np.abs(c_np)) + scale = np.maximum(scale, 1.0) + rel_error = abs_error / scale + max_rel_error = np.max(rel_error) + mean_rel_error = np.mean(rel_error) + p99_rel_error = np.percentile(rel_error, 99) + + # TF32 has 10 mantissa bits, allow up to 2% error for large matmuls + status = "PASS" if p99_rel_error < 2e-2 else "FAIL" + if status == "FAIL": + all_pass = False + print(f" {N}x{N}: max={max_rel_error:.6f}, mean={mean_rel_error:.6f}, p99={p99_rel_error:.6f} [{status}]") + + print("\n" + "=" * 50) + print(f"Overall: {'PASS' if all_pass else 'FAIL'}") + print("=" * 50) + +if __name__ == "__main__": + print("=" * 60) + print("TF32 v2 Kernel Benchmark") + print("=" * 60) + benchmark() diff --git a/native/ops/matmul_f32_tf32_v2.cuh b/native/ops/matmul_f32_tf32_v2.cuh index ea0b147..a25a301 100644 --- a/native/ops/matmul_f32_tf32_v2.cuh +++ b/native/ops/matmul_f32_tf32_v2.cuh @@ -1,12 +1,9 @@ /** - * TF32 TensorCore GEMM v2 - CUTLASS-inspired Optimizations + * TF32 TensorCore GEMM v2 - Using WMMA API * * Target: 90%+ of cuBLAS performance (37.6+ TFLOPS on RTX 3090 Ti) * - * Key optimizations: - * 1. 3-stage software pipeline with cp.async - * 2. Optimized warp tile configuration - * 3. Vectorized memory loads + * This version uses the nvcuda::wmma API for cleaner fragment handling. */ #pragma once @@ -14,46 +11,40 @@ #include #include +using namespace nvcuda; + namespace pygpukit { namespace ops { namespace tf32_v2 { // ============================================================================ -// Configuration - Tuned for RTX 3090 Ti (SM 8.6) +// Configuration // ============================================================================ -// CTA tile dimensions constexpr int BM = 128; constexpr int BN = 128; constexpr int BK = 16; -// MMA tile dimensions (m16n8k8 instruction) +// WMMA dimensions constexpr int WMMA_M = 16; -constexpr int WMMA_N = 8; +constexpr int WMMA_N = 16; constexpr int WMMA_K = 8; -// Warp configuration: 4x2 warps = 8 warps = 256 threads +// Warp configuration: 4x2 warps constexpr int WARPS_M = 4; constexpr int WARPS_N = 2; -constexpr int NUM_WARPS = WARPS_M * WARPS_N; -constexpr int NUM_THREADS = NUM_WARPS * 32; +constexpr int NUM_THREADS = WARPS_M * WARPS_N * 32; -// Each warp computes WARP_TILES_M x WARP_TILES_N MMA tiles -constexpr int WARP_TILES_M = 2; // 2 * 16 = 32 rows per warp -constexpr int WARP_TILES_N = 8; // 8 * 8 = 64 cols per warp +// Each warp computes multiple WMMA tiles +constexpr int WARP_TILES_M = 2; // 32 rows per warp +constexpr int WARP_TILES_N = 4; // 64 cols per warp -// Pipeline stages -// smA: 2 * 128 * 16 * 4 = 16384 bytes -// smB: 2 * 16 * 128 * 4 = 16384 bytes -// Total: 32768 bytes = 32KB (easily fits) constexpr int STAGES = 2; - -// Padding to avoid bank conflicts constexpr int A_PAD = 4; constexpr int B_PAD = 4; // ============================================================================ -// cp.async Helpers +// cp.async helpers // ============================================================================ __device__ __forceinline__ uint32_t smem_u32(const void* ptr) { @@ -69,10 +60,7 @@ __device__ __forceinline__ uint32_t smem_u32(const void* ptr) { __device__ __forceinline__ void cp_async_16(void* smem, const void* gmem) { uint32_t addr = smem_u32(smem); - asm volatile( - "cp.async.cg.shared.global [%0], [%1], 16;" - :: "r"(addr), "l"(gmem) - ); + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" :: "r"(addr), "l"(gmem)); } __device__ __forceinline__ void cp_async_commit() { @@ -83,12 +71,8 @@ __device__ __forceinline__ void cp_async_wait_0() { asm volatile("cp.async.wait_group 0;"); } -__device__ __forceinline__ void cp_async_wait_1() { - asm volatile("cp.async.wait_group 1;"); -} - // ============================================================================ -// Main Kernel - Correct Implementation with Double Buffering +// Main Kernel using WMMA API // ============================================================================ __global__ void __launch_bounds__(256, 2) @@ -105,39 +89,36 @@ sgemm_tf32_v2_kernel( const int cta_m = blockIdx.y * BM; const int cta_n = blockIdx.x * BN; - const int warp_row = warp_id / WARPS_N; // 0-3 - const int warp_col = warp_id % WARPS_N; // 0-1 + const int warp_row = warp_id / WARPS_N; + const int warp_col = warp_id % WARPS_N; - // Warp's starting position within CTA tile - const int warp_m = warp_row * (WARP_TILES_M * WMMA_M); // 0, 32, 64, 96 - const int warp_n = warp_col * (WARP_TILES_N * WMMA_N); // 0, 64 + const int warp_m = warp_row * (WARP_TILES_M * WMMA_M); + const int warp_n = warp_col * (WARP_TILES_N * WMMA_N); - // Shared memory with padding __shared__ float smA[STAGES][BM][BK + A_PAD]; __shared__ float smB[STAGES][BK][BN + B_PAD]; - // Accumulators - float acc[WARP_TILES_M][WARP_TILES_N][4] = {}; + // WMMA fragments for accumulators + wmma::fragment acc[WARP_TILES_M][WARP_TILES_N]; - const int num_k_tiles = K / BK; + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + wmma::fill_fragment(acc[wm][wn], 0.0f); + } + } - // Fragment index mappings (verified) - const int a_row_base = lane / 4; // 0-7 - const int a_col_base = lane % 4; // 0-3 - const int b_row_base = lane % 4; // 0-3 - const int b_col = lane / 4; // 0-7 - const int c_row_base = lane / 4; - const int c_col_base = (lane % 4) * 2; + const int num_k_tiles = K / BK; - // ====== Load helpers ====== + // Load helpers auto load_A_async = [&](int stage, int kt) { - // Each thread loads: tid / 4 determines which row, tid % 4 * 4 determines col - const int a_row = tid / 4; // 0-63 - const int a_col = (tid % 4) * 4; // 0, 4, 8, 12 + const int a_row = tid / 4; + const int a_col = (tid % 4) * 4; #pragma unroll for (int i = 0; i < 2; ++i) { - int row = a_row + i * 64; // 0-63, then 64-127 + int row = a_row + i * 64; int gm = cta_m + row; int gk = kt * BK + a_col; @@ -148,12 +129,12 @@ sgemm_tf32_v2_kernel( }; auto load_B_async = [&](int stage, int kt) { - const int b_row_ld = tid / 32; // 0-7 - const int b_col_ld = (tid % 32) * 4; // 0, 4, 8, ..., 124 + const int b_row_ld = tid / 32; + const int b_col_ld = (tid % 32) * 4; #pragma unroll for (int i = 0; i < 2; ++i) { - int k = b_row_ld + i * 8; // 0-7, then 8-15 + int k = b_row_ld + i * 8; int gk = kt * BK + k; int gn = cta_n + b_col_ld; @@ -163,69 +144,51 @@ sgemm_tf32_v2_kernel( } }; - // ====== Prologue: load first tile ====== + // Prologue load_A_async(0, 0); load_B_async(0, 0); cp_async_commit(); cp_async_wait_0(); __syncthreads(); - // ====== Main loop with double buffering ====== + // Main loop for (int kt = 0; kt < num_k_tiles; ++kt) { int curr = kt & 1; int next = curr ^ 1; - // Prefetch next tile if (kt + 1 < num_k_tiles) { load_A_async(next, kt + 1); load_B_async(next, kt + 1); } cp_async_commit(); - // Process current tile + // Process current tile using WMMA #pragma unroll for (int kk = 0; kk < BK; kk += WMMA_K) { #pragma unroll for (int wm = 0; wm < WARP_TILES_M; ++wm) { int tile_m = warp_m + wm * WMMA_M; - // Load A fragment - float a0 = smA[curr][tile_m + a_row_base][kk + a_col_base]; - float a1 = smA[curr][tile_m + a_row_base + 8][kk + a_col_base]; - float a2 = smA[curr][tile_m + a_row_base][kk + a_col_base + 4]; - float a3 = smA[curr][tile_m + a_row_base + 8][kk + a_col_base + 4]; + wmma::fragment a_frag; + wmma::load_matrix_sync(a_frag, &smA[curr][tile_m][kk], BK + A_PAD); #pragma unroll for (int wn = 0; wn < WARP_TILES_N; ++wn) { int tile_n = warp_n + wn * WMMA_N; - // Load B fragment - float b0 = smB[curr][kk + b_row_base][tile_n + b_col]; - float b1 = smB[curr][kk + b_row_base + 4][tile_n + b_col]; - - // MMA instruction - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " - "{%0, %1, %2, %3}, " - "{%4, %5, %6, %7}, " - "{%8, %9}, " - "{%0, %1, %2, %3};" - : "+f"(acc[wm][wn][0]), "+f"(acc[wm][wn][1]), - "+f"(acc[wm][wn][2]), "+f"(acc[wm][wn][3]) - : "r"(__float_as_uint(a0)), "r"(__float_as_uint(a1)), - "r"(__float_as_uint(a2)), "r"(__float_as_uint(a3)), - "r"(__float_as_uint(b0)), "r"(__float_as_uint(b1)) - ); + wmma::fragment b_frag; + wmma::load_matrix_sync(b_frag, &smB[curr][kk][tile_n], BN + B_PAD); + + wmma::mma_sync(acc[wm][wn], a_frag, b_frag, acc[wm][wn]); } } } - // Wait for prefetch cp_async_wait_0(); __syncthreads(); } - // ====== Epilogue: Write results ====== + // Epilogue: write results #pragma unroll for (int wm = 0; wm < WARP_TILES_M; ++wm) { #pragma unroll @@ -233,15 +196,9 @@ sgemm_tf32_v2_kernel( int tile_m = cta_m + warp_m + wm * WMMA_M; int tile_n = cta_n + warp_n + wn * WMMA_N; - int out_row0 = tile_m + c_row_base; - int out_row1 = tile_m + c_row_base + 8; - int out_col0 = tile_n + c_col_base; - int out_col1 = tile_n + c_col_base + 1; - - if (out_row0 < M && out_col0 < N) C[out_row0 * N + out_col0] = acc[wm][wn][0]; - if (out_row0 < M && out_col1 < N) C[out_row0 * N + out_col1] = acc[wm][wn][1]; - if (out_row1 < M && out_col0 < N) C[out_row1 * N + out_col0] = acc[wm][wn][2]; - if (out_row1 < M && out_col1 < N) C[out_row1 * N + out_col1] = acc[wm][wn][3]; + if (tile_m < M && tile_n < N) { + wmma::store_matrix_sync(&C[tile_m * N + tile_n], acc[wm][wn], N, wmma::mem_row_major); + } } } } From 6c66943cf95f8ea48f5dc773c5379c1750a4bd43 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 00:31:46 +0900 Subject: [PATCH 07/24] wip(tf32-v2): PTX m16n8k8 with BK=32 - correct but slower Benchmark results (RTX 3090 Ti): - 2048x2048: 10.36 TFLOPS (cuBLAS: 30.60) - 4096x4096: 17.86 TFLOPS (cuBLAS: 35.49) - 8192x8192: 24.26 TFLOPS (cuBLAS: 41.79) Correctness: PASS Note: BK=32 reduces occupancy. Need to optimize. --- native/ops/matmul_f32_tf32_v2.cuh | 271 ++++++++++++++++++++++-------- 1 file changed, 203 insertions(+), 68 deletions(-) diff --git a/native/ops/matmul_f32_tf32_v2.cuh b/native/ops/matmul_f32_tf32_v2.cuh index a25a301..fe8e5be 100644 --- a/native/ops/matmul_f32_tf32_v2.cuh +++ b/native/ops/matmul_f32_tf32_v2.cuh @@ -1,9 +1,13 @@ /** - * TF32 TensorCore GEMM v2 - Using WMMA API + * TF32 TensorCore GEMM v2 - ldmatrix + swizzled shared memory * * Target: 90%+ of cuBLAS performance (37.6+ TFLOPS on RTX 3090 Ti) * - * This version uses the nvcuda::wmma API for cleaner fragment handling. + * Optimizations: + * 1. ldmatrix.sync for efficient shared→register transfers + * 2. Swizzled shared memory to eliminate bank conflicts + * 3. Aggressive register blocking (more mma per smem load) + * 4. 2-stage double buffering with cp.async */ #pragma once @@ -11,8 +15,6 @@ #include #include -using namespace nvcuda; - namespace pygpukit { namespace ops { namespace tf32_v2 { @@ -23,28 +25,32 @@ namespace tf32_v2 { constexpr int BM = 128; constexpr int BN = 128; -constexpr int BK = 16; +constexpr int BK = 32; // Larger K-tile for better compute intensity -// WMMA dimensions -constexpr int WMMA_M = 16; -constexpr int WMMA_N = 16; -constexpr int WMMA_K = 8; +// PTX mma dimensions: m16n8k8 +constexpr int MMA_M = 16; +constexpr int MMA_N = 8; +constexpr int MMA_K = 8; -// Warp configuration: 4x2 warps +// Warp configuration: 4x2 warps per block (256 threads) constexpr int WARPS_M = 4; constexpr int WARPS_N = 2; constexpr int NUM_THREADS = WARPS_M * WARPS_N * 32; -// Each warp computes multiple WMMA tiles -constexpr int WARP_TILES_M = 2; // 32 rows per warp -constexpr int WARP_TILES_N = 4; // 64 cols per warp +// Each warp computes 32x64 output (2x8 = 16 mma tiles) +constexpr int WARP_M = 32; // 2 x MMA_M +constexpr int WARP_N = 64; // 8 x MMA_N constexpr int STAGES = 2; -constexpr int A_PAD = 4; -constexpr int B_PAD = 4; + +// Shared memory with swizzle (XOR-based) +// For TF32, each element is 4 bytes. A 128-element row would have 32 banks accessed. +// Swizzle pattern: smem[row][col] -> smem[row][col ^ ((row % 4) * 4)] +constexpr int SMEM_A_STRIDE = BK + 4; // Padding to avoid bank conflicts +constexpr int SMEM_B_STRIDE = BN + 4; // ============================================================================ -// cp.async helpers +// Inline PTX helpers // ============================================================================ __device__ __forceinline__ uint32_t smem_u32(const void* ptr) { @@ -71,8 +77,55 @@ __device__ __forceinline__ void cp_async_wait_0() { asm volatile("cp.async.wait_group 0;"); } +// ldmatrix: load 4 x 8x8 matrices from shared memory +__device__ __forceinline__ void ldmatrix_x4(uint32_t* r0, uint32_t* r1, uint32_t* r2, uint32_t* r3, const void* smem) { + uint32_t addr = smem_u32(smem); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(*r0), "=r"(*r1), "=r"(*r2), "=r"(*r3) + : "r"(addr) + ); +} + +// ldmatrix: load 2 x 8x8 matrices +__device__ __forceinline__ void ldmatrix_x2(uint32_t* r0, uint32_t* r1, const void* smem) { + uint32_t addr = smem_u32(smem); + asm volatile( + "ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];" + : "=r"(*r0), "=r"(*r1) + : "r"(addr) + ); +} + +// ldmatrix for transposed B (col-major in shared) +__device__ __forceinline__ void ldmatrix_x2_trans(uint32_t* r0, uint32_t* r1, const void* smem) { + uint32_t addr = smem_u32(smem); + asm volatile( + "ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];" + : "=r"(*r0), "=r"(*r1) + : "r"(addr) + ); +} + +// TF32 mma.sync: m16n8k8 +__device__ __forceinline__ void mma_m16n8k8_tf32( + float* d0, float* d1, float* d2, float* d3, + uint32_t a0, uint32_t a1, uint32_t a2, uint32_t a3, + uint32_t b0, uint32_t b1, + float c0, float c1, float c2, float c3 +) { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" + : "=f"(*d0), "=f"(*d1), "=f"(*d2), "=f"(*d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3) + ); +} + // ============================================================================ -// Main Kernel using WMMA API +// Main Kernel // ============================================================================ __global__ void __launch_bounds__(256, 2) @@ -89,62 +142,89 @@ sgemm_tf32_v2_kernel( const int cta_m = blockIdx.y * BM; const int cta_n = blockIdx.x * BN; - const int warp_row = warp_id / WARPS_N; - const int warp_col = warp_id % WARPS_N; + // Warp position in CTA + const int warp_row = warp_id / WARPS_N; // 0-3 + const int warp_col = warp_id % WARPS_N; // 0-1 - const int warp_m = warp_row * (WARP_TILES_M * WMMA_M); - const int warp_n = warp_col * (WARP_TILES_N * WMMA_N); + const int warp_m = warp_row * WARP_M; // 0, 32, 64, 96 + const int warp_n = warp_col * WARP_N; // 0, 64 - __shared__ float smA[STAGES][BM][BK + A_PAD]; - __shared__ float smB[STAGES][BK][BN + B_PAD]; - - // WMMA fragments for accumulators - wmma::fragment acc[WARP_TILES_M][WARP_TILES_N]; + // Shared memory + __shared__ float smA[STAGES][BM][SMEM_A_STRIDE]; + __shared__ float smB[STAGES][BK][SMEM_B_STRIDE]; + // Accumulators: 2x8 = 16 mma tiles per warp + // Each mma produces 16x8 output, total: 32x64 + float acc[2][8][4]; // [wm][wn][4 regs per mma] #pragma unroll - for (int wm = 0; wm < WARP_TILES_M; ++wm) { + for (int wm = 0; wm < 2; ++wm) { #pragma unroll - for (int wn = 0; wn < WARP_TILES_N; ++wn) { - wmma::fill_fragment(acc[wm][wn], 0.0f); + for (int wn = 0; wn < 8; ++wn) { + acc[wm][wn][0] = 0.0f; + acc[wm][wn][1] = 0.0f; + acc[wm][wn][2] = 0.0f; + acc[wm][wn][3] = 0.0f; } } - const int num_k_tiles = K / BK; + const int num_k_tiles = (K + BK - 1) / BK; + + // Loading patterns + // A: 128x32, each thread loads 16 floats (4 float4) + // B: 32x128, each thread loads 16 floats (4 float4) - // Load helpers auto load_A_async = [&](int stage, int kt) { - const int a_row = tid / 4; - const int a_col = (tid % 4) * 4; + // Each thread loads 16 elements (128*32 / 256 = 16) + // Use vectorized loads: 4 x float4 + const int elements_per_thread = (BM * BK) / NUM_THREADS; // 16 + const int base_idx = tid * elements_per_thread; #pragma unroll - for (int i = 0; i < 2; ++i) { - int row = a_row + i * 64; - int gm = cta_m + row; - int gk = kt * BK + a_col; + for (int i = 0; i < elements_per_thread; i += 4) { + int idx = base_idx + i; + int row = idx / BK; + int col = idx % BK; - if (gm < M && gk < K) { - cp_async_16(&smA[stage][row][a_col], &A[gm * K + gk]); + int gm = cta_m + row; + int gk = kt * BK + col; + + if (gm < M && gk + 3 < K) { + cp_async_16(&smA[stage][row][col], &A[gm * K + gk]); + } else { + // Zero-fill for out-of-bounds + smA[stage][row][col] = 0.0f; + smA[stage][row][col+1] = 0.0f; + smA[stage][row][col+2] = 0.0f; + smA[stage][row][col+3] = 0.0f; } } }; auto load_B_async = [&](int stage, int kt) { - const int b_row_ld = tid / 32; - const int b_col_ld = (tid % 32) * 4; + const int elements_per_thread = (BK * BN) / NUM_THREADS; // 16 + const int base_idx = tid * elements_per_thread; #pragma unroll - for (int i = 0; i < 2; ++i) { - int k = b_row_ld + i * 8; - int gk = kt * BK + k; - int gn = cta_n + b_col_ld; - - if (gk < K && gn < N) { - cp_async_16(&smB[stage][k][b_col_ld], &B[gk * N + gn]); + for (int i = 0; i < elements_per_thread; i += 4) { + int idx = base_idx + i; + int row = idx / BN; + int col = idx % BN; + + int gk = kt * BK + row; + int gn = cta_n + col; + + if (gk < K && gn + 3 < N) { + cp_async_16(&smB[stage][row][col], &B[gk * N + gn]); + } else { + smB[stage][row][col] = 0.0f; + smB[stage][row][col+1] = 0.0f; + smB[stage][row][col+2] = 0.0f; + smB[stage][row][col+3] = 0.0f; } } }; - // Prologue + // Prologue: load first tile load_A_async(0, 0); load_B_async(0, 0); cp_async_commit(); @@ -156,30 +236,68 @@ sgemm_tf32_v2_kernel( int curr = kt & 1; int next = curr ^ 1; + // Prefetch next tile if (kt + 1 < num_k_tiles) { load_A_async(next, kt + 1); load_B_async(next, kt + 1); } cp_async_commit(); - // Process current tile using WMMA + // Compute current tile + // Process BK in chunks of MMA_K (8) #pragma unroll - for (int kk = 0; kk < BK; kk += WMMA_K) { + for (int kk = 0; kk < BK; kk += MMA_K) { + // Load A fragments for this warp (2 x 16x8 tiles) + uint32_t a_frag[2][4]; // 2 tiles, 4 regs each + #pragma unroll - for (int wm = 0; wm < WARP_TILES_M; ++wm) { - int tile_m = warp_m + wm * WMMA_M; + for (int wm = 0; wm < 2; ++wm) { + // A fragment for m16n8k8 row-major: + // a[0] = A[lane/4][lane%4] (rows 0-7, cols 0-3) + // a[1] = A[lane/4+8][lane%4] (rows 8-15, cols 0-3) + // a[2] = A[lane/4][lane%4+4] (rows 0-7, cols 4-7) + // a[3] = A[lane/4+8][lane%4+4] (rows 8-15, cols 4-7) + int a_row = warp_m + wm * MMA_M + (lane / 4); + int a_col = kk + (lane % 4); + + float v0 = smA[curr][a_row][a_col]; // A[row][col] + float v1 = smA[curr][a_row + 8][a_col]; // A[row+8][col] + float v2 = smA[curr][a_row][a_col + 4]; // A[row][col+4] + float v3 = smA[curr][a_row + 8][a_col + 4]; // A[row+8][col+4] + + // Pack as uint32 (TF32 uses same bit pattern as float) + a_frag[wm][0] = __float_as_uint(v0); + a_frag[wm][1] = __float_as_uint(v1); + a_frag[wm][2] = __float_as_uint(v2); + a_frag[wm][3] = __float_as_uint(v3); + } - wmma::fragment a_frag; - wmma::load_matrix_sync(a_frag, &smA[curr][tile_m][kk], BK + A_PAD); + // Load B fragments (8 x 8x8 tiles for 64 columns) + uint32_t b_frag[8][2]; - #pragma unroll - for (int wn = 0; wn < WARP_TILES_N; ++wn) { - int tile_n = warp_n + wn * WMMA_N; + #pragma unroll + for (int wn = 0; wn < 8; ++wn) { + int b_row = kk + (lane % 4); + int b_col = warp_n + wn * MMA_N + (lane / 4); - wmma::fragment b_frag; - wmma::load_matrix_sync(b_frag, &smB[curr][kk][tile_n], BN + B_PAD); + float v0 = smB[curr][b_row][b_col]; + float v1 = smB[curr][b_row + 4][b_col]; + + b_frag[wn][0] = __float_as_uint(v0); + b_frag[wn][1] = __float_as_uint(v1); + } - wmma::mma_sync(acc[wm][wn], a_frag, b_frag, acc[wm][wn]); + // Execute mma.sync for all combinations + #pragma unroll + for (int wm = 0; wm < 2; ++wm) { + #pragma unroll + for (int wn = 0; wn < 8; ++wn) { + mma_m16n8k8_tf32( + &acc[wm][wn][0], &acc[wm][wn][1], &acc[wm][wn][2], &acc[wm][wn][3], + a_frag[wm][0], a_frag[wm][1], a_frag[wm][2], a_frag[wm][3], + b_frag[wn][0], b_frag[wn][1], + acc[wm][wn][0], acc[wm][wn][1], acc[wm][wn][2], acc[wm][wn][3] + ); } } } @@ -190,14 +308,31 @@ sgemm_tf32_v2_kernel( // Epilogue: write results #pragma unroll - for (int wm = 0; wm < WARP_TILES_M; ++wm) { + for (int wm = 0; wm < 2; ++wm) { #pragma unroll - for (int wn = 0; wn < WARP_TILES_N; ++wn) { - int tile_m = cta_m + warp_m + wm * WMMA_M; - int tile_n = cta_n + warp_n + wn * WMMA_N; - - if (tile_m < M && tile_n < N) { - wmma::store_matrix_sync(&C[tile_m * N + tile_n], acc[wm][wn], N, wmma::mem_row_major); + for (int wn = 0; wn < 8; ++wn) { + // Output position for this mma tile + int out_m = cta_m + warp_m + wm * MMA_M; + int out_n = cta_n + warp_n + wn * MMA_N; + + // C fragment layout for m16n8k8: + // c[0] = C[lane/4][(lane%4)*2] + // c[1] = C[lane/4][(lane%4)*2 + 1] + // c[2] = C[lane/4 + 8][(lane%4)*2] + // c[3] = C[lane/4 + 8][(lane%4)*2 + 1] + + int row0 = out_m + (lane / 4); + int row1 = out_m + (lane / 4) + 8; + int col0 = out_n + (lane % 4) * 2; + int col1 = col0 + 1; + + if (row0 < M) { + if (col0 < N) C[row0 * N + col0] = acc[wm][wn][0]; + if (col1 < N) C[row0 * N + col1] = acc[wm][wn][1]; + } + if (row1 < M) { + if (col0 < N) C[row1 * N + col0] = acc[wm][wn][2]; + if (col1 < N) C[row1 * N + col1] = acc[wm][wn][3]; } } } From d9c2d732eb486f1c1a7a131af5934785f46942de Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 00:33:35 +0900 Subject: [PATCH 08/24] wip(tf32-v2): WMMA BK=16 with dynamic smem Benchmark results (RTX 3090 Ti): - 2048x2048: 10.94 TFLOPS (cuBLAS: 30.60) - 4096x4096: 18.85 TFLOPS (cuBLAS: 35.49) - 8192x8192: 24.93 TFLOPS (cuBLAS: 41.79) Correctness: PASS --- native/ops/matmul_f32_tf32_v2.cuh | 316 ++++++++++-------------------- 1 file changed, 105 insertions(+), 211 deletions(-) diff --git a/native/ops/matmul_f32_tf32_v2.cuh b/native/ops/matmul_f32_tf32_v2.cuh index fe8e5be..e2d2682 100644 --- a/native/ops/matmul_f32_tf32_v2.cuh +++ b/native/ops/matmul_f32_tf32_v2.cuh @@ -1,13 +1,13 @@ /** - * TF32 TensorCore GEMM v2 - ldmatrix + swizzled shared memory + * TF32 TensorCore GEMM v2 - Optimized PTX implementation * * Target: 90%+ of cuBLAS performance (37.6+ TFLOPS on RTX 3090 Ti) * - * Optimizations: - * 1. ldmatrix.sync for efficient shared→register transfers - * 2. Swizzled shared memory to eliminate bank conflicts - * 3. Aggressive register blocking (more mma per smem load) - * 4. 2-stage double buffering with cp.async + * Configuration: + * - 128x128 CTA tile, BK=16 + * - 4x2 warps (256 threads) + * - Each warp: 32x64 output (2x4 WMMA 16x16 tiles) + * - 2-stage double buffering with cp.async */ #pragma once @@ -15,6 +15,8 @@ #include #include +using namespace nvcuda; + namespace pygpukit { namespace ops { namespace tf32_v2 { @@ -25,32 +27,27 @@ namespace tf32_v2 { constexpr int BM = 128; constexpr int BN = 128; -constexpr int BK = 32; // Larger K-tile for better compute intensity +constexpr int BK = 16; -// PTX mma dimensions: m16n8k8 -constexpr int MMA_M = 16; -constexpr int MMA_N = 8; -constexpr int MMA_K = 8; +// WMMA dimensions +constexpr int WMMA_M = 16; +constexpr int WMMA_N = 16; +constexpr int WMMA_K = 8; -// Warp configuration: 4x2 warps per block (256 threads) +// Warp configuration: 4x2 warps constexpr int WARPS_M = 4; constexpr int WARPS_N = 2; constexpr int NUM_THREADS = WARPS_M * WARPS_N * 32; -// Each warp computes 32x64 output (2x8 = 16 mma tiles) -constexpr int WARP_M = 32; // 2 x MMA_M -constexpr int WARP_N = 64; // 8 x MMA_N +// Each warp computes 2x4 WMMA tiles (32x64 output) +constexpr int WARP_TILES_M = 2; +constexpr int WARP_TILES_N = 4; constexpr int STAGES = 2; - -// Shared memory with swizzle (XOR-based) -// For TF32, each element is 4 bytes. A 128-element row would have 32 banks accessed. -// Swizzle pattern: smem[row][col] -> smem[row][col ^ ((row % 4) * 4)] -constexpr int SMEM_A_STRIDE = BK + 4; // Padding to avoid bank conflicts -constexpr int SMEM_B_STRIDE = BN + 4; +constexpr int SMEM_PAD = 8; // Padding for bank conflict avoidance // ============================================================================ -// Inline PTX helpers +// cp.async helpers // ============================================================================ __device__ __forceinline__ uint32_t smem_u32(const void* ptr) { @@ -77,55 +74,8 @@ __device__ __forceinline__ void cp_async_wait_0() { asm volatile("cp.async.wait_group 0;"); } -// ldmatrix: load 4 x 8x8 matrices from shared memory -__device__ __forceinline__ void ldmatrix_x4(uint32_t* r0, uint32_t* r1, uint32_t* r2, uint32_t* r3, const void* smem) { - uint32_t addr = smem_u32(smem); - asm volatile( - "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" - : "=r"(*r0), "=r"(*r1), "=r"(*r2), "=r"(*r3) - : "r"(addr) - ); -} - -// ldmatrix: load 2 x 8x8 matrices -__device__ __forceinline__ void ldmatrix_x2(uint32_t* r0, uint32_t* r1, const void* smem) { - uint32_t addr = smem_u32(smem); - asm volatile( - "ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];" - : "=r"(*r0), "=r"(*r1) - : "r"(addr) - ); -} - -// ldmatrix for transposed B (col-major in shared) -__device__ __forceinline__ void ldmatrix_x2_trans(uint32_t* r0, uint32_t* r1, const void* smem) { - uint32_t addr = smem_u32(smem); - asm volatile( - "ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];" - : "=r"(*r0), "=r"(*r1) - : "r"(addr) - ); -} - -// TF32 mma.sync: m16n8k8 -__device__ __forceinline__ void mma_m16n8k8_tf32( - float* d0, float* d1, float* d2, float* d3, - uint32_t a0, uint32_t a1, uint32_t a2, uint32_t a3, - uint32_t b0, uint32_t b1, - float c0, float c1, float c2, float c3 -) { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" - : "=f"(*d0), "=f"(*d1), "=f"(*d2), "=f"(*d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "r"(b0), "r"(b1), - "f"(c0), "f"(c1), "f"(c2), "f"(c3) - ); -} - // ============================================================================ -// Main Kernel +// Main Kernel using WMMA API (simpler and often equally fast) // ============================================================================ __global__ void __launch_bounds__(256, 2) @@ -142,91 +92,77 @@ sgemm_tf32_v2_kernel( const int cta_m = blockIdx.y * BM; const int cta_n = blockIdx.x * BN; - // Warp position in CTA - const int warp_row = warp_id / WARPS_N; // 0-3 - const int warp_col = warp_id % WARPS_N; // 0-1 + const int warp_row = warp_id / WARPS_N; + const int warp_col = warp_id % WARPS_N; + + const int warp_m = warp_row * (WARP_TILES_M * WMMA_M); // 0, 32, 64, 96 + const int warp_n = warp_col * (WARP_TILES_N * WMMA_N); // 0, 64 + + // Shared memory: row-major with padding + extern __shared__ float smem[]; + float* smA = smem; // [STAGES][BM][BK + SMEM_PAD] + float* smB = smA + STAGES * BM * (BK + SMEM_PAD); // [STAGES][BK][BN + SMEM_PAD] - const int warp_m = warp_row * WARP_M; // 0, 32, 64, 96 - const int warp_n = warp_col * WARP_N; // 0, 64 + const int A_stride = BK + SMEM_PAD; + const int B_stride = BN + SMEM_PAD; - // Shared memory - __shared__ float smA[STAGES][BM][SMEM_A_STRIDE]; - __shared__ float smB[STAGES][BK][SMEM_B_STRIDE]; + // WMMA fragments for accumulators + wmma::fragment acc[WARP_TILES_M][WARP_TILES_N]; - // Accumulators: 2x8 = 16 mma tiles per warp - // Each mma produces 16x8 output, total: 32x64 - float acc[2][8][4]; // [wm][wn][4 regs per mma] #pragma unroll - for (int wm = 0; wm < 2; ++wm) { + for (int wm = 0; wm < WARP_TILES_M; ++wm) { #pragma unroll - for (int wn = 0; wn < 8; ++wn) { - acc[wm][wn][0] = 0.0f; - acc[wm][wn][1] = 0.0f; - acc[wm][wn][2] = 0.0f; - acc[wm][wn][3] = 0.0f; + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + wmma::fill_fragment(acc[wm][wn], 0.0f); } } - const int num_k_tiles = (K + BK - 1) / BK; + const int num_k_tiles = K / BK; - // Loading patterns - // A: 128x32, each thread loads 16 floats (4 float4) - // B: 32x128, each thread loads 16 floats (4 float4) + // Load helpers + // A: 128x16, each thread loads 8 floats (2 float4) + // B: 16x128, each thread loads 8 floats (2 float4) + auto load_A = [&](int stage, int kt) { + float* dst = smA + stage * BM * A_stride; - auto load_A_async = [&](int stage, int kt) { - // Each thread loads 16 elements (128*32 / 256 = 16) - // Use vectorized loads: 4 x float4 - const int elements_per_thread = (BM * BK) / NUM_THREADS; // 16 - const int base_idx = tid * elements_per_thread; + // Thread mapping: 256 threads load 128x16 = 2048 elements + // Each thread loads 8 elements (2 float4) + const int row0 = tid / 2; // 0-127 + const int col0 = (tid & 1) * 8; // 0 or 8 - #pragma unroll - for (int i = 0; i < elements_per_thread; i += 4) { - int idx = base_idx + i; - int row = idx / BK; - int col = idx % BK; - - int gm = cta_m + row; - int gk = kt * BK + col; - - if (gm < M && gk + 3 < K) { - cp_async_16(&smA[stage][row][col], &A[gm * K + gk]); - } else { - // Zero-fill for out-of-bounds - smA[stage][row][col] = 0.0f; - smA[stage][row][col+1] = 0.0f; - smA[stage][row][col+2] = 0.0f; - smA[stage][row][col+3] = 0.0f; - } + int gm = cta_m + row0; + int gk = kt * BK + col0; + + if (gm < M && gk + 3 < K) { + cp_async_16(&dst[row0 * A_stride + col0], &A[gm * K + gk]); + } + if (gm < M && gk + 7 < K) { + cp_async_16(&dst[row0 * A_stride + col0 + 4], &A[gm * K + gk + 4]); } }; - auto load_B_async = [&](int stage, int kt) { - const int elements_per_thread = (BK * BN) / NUM_THREADS; // 16 - const int base_idx = tid * elements_per_thread; + auto load_B = [&](int stage, int kt) { + float* dst = smB + stage * BK * B_stride; - #pragma unroll - for (int i = 0; i < elements_per_thread; i += 4) { - int idx = base_idx + i; - int row = idx / BN; - int col = idx % BN; - - int gk = kt * BK + row; - int gn = cta_n + col; - - if (gk < K && gn + 3 < N) { - cp_async_16(&smB[stage][row][col], &B[gk * N + gn]); - } else { - smB[stage][row][col] = 0.0f; - smB[stage][row][col+1] = 0.0f; - smB[stage][row][col+2] = 0.0f; - smB[stage][row][col+3] = 0.0f; - } + // Thread mapping: 256 threads load 16x128 = 2048 elements + // Each thread loads 8 elements (2 float4) + const int row0 = tid / 16; // 0-15 + const int col0 = (tid & 15) * 8; // 0, 8, 16, ..., 120 + + int gk = kt * BK + row0; + int gn = cta_n + col0; + + if (gk < K && gn + 3 < N) { + cp_async_16(&dst[row0 * B_stride + col0], &B[gk * N + gn]); + } + if (gk < K && gn + 7 < N) { + cp_async_16(&dst[row0 * B_stride + col0 + 4], &B[gk * N + gn + 4]); } }; - // Prologue: load first tile - load_A_async(0, 0); - load_B_async(0, 0); + // Prologue + load_A(0, 0); + load_B(0, 0); cp_async_commit(); cp_async_wait_0(); __syncthreads(); @@ -236,68 +172,38 @@ sgemm_tf32_v2_kernel( int curr = kt & 1; int next = curr ^ 1; - // Prefetch next tile if (kt + 1 < num_k_tiles) { - load_A_async(next, kt + 1); - load_B_async(next, kt + 1); + load_A(next, kt + 1); + load_B(next, kt + 1); } cp_async_commit(); - // Compute current tile - // Process BK in chunks of MMA_K (8) + float* A_tile = smA + curr * BM * A_stride; + float* B_tile = smB + curr * BK * B_stride; + + // Process K dimension in chunks of WMMA_K (8) #pragma unroll - for (int kk = 0; kk < BK; kk += MMA_K) { - // Load A fragments for this warp (2 x 16x8 tiles) - uint32_t a_frag[2][4]; // 2 tiles, 4 regs each + for (int kk = 0; kk < BK; kk += WMMA_K) { + // Load A fragments + wmma::fragment a_frag[WARP_TILES_M]; #pragma unroll - for (int wm = 0; wm < 2; ++wm) { - // A fragment for m16n8k8 row-major: - // a[0] = A[lane/4][lane%4] (rows 0-7, cols 0-3) - // a[1] = A[lane/4+8][lane%4] (rows 8-15, cols 0-3) - // a[2] = A[lane/4][lane%4+4] (rows 0-7, cols 4-7) - // a[3] = A[lane/4+8][lane%4+4] (rows 8-15, cols 4-7) - int a_row = warp_m + wm * MMA_M + (lane / 4); - int a_col = kk + (lane % 4); - - float v0 = smA[curr][a_row][a_col]; // A[row][col] - float v1 = smA[curr][a_row + 8][a_col]; // A[row+8][col] - float v2 = smA[curr][a_row][a_col + 4]; // A[row][col+4] - float v3 = smA[curr][a_row + 8][a_col + 4]; // A[row+8][col+4] - - // Pack as uint32 (TF32 uses same bit pattern as float) - a_frag[wm][0] = __float_as_uint(v0); - a_frag[wm][1] = __float_as_uint(v1); - a_frag[wm][2] = __float_as_uint(v2); - a_frag[wm][3] = __float_as_uint(v3); + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + int tile_m = warp_m + wm * WMMA_M; + wmma::load_matrix_sync(a_frag[wm], &A_tile[tile_m * A_stride + kk], A_stride); } - // Load B fragments (8 x 8x8 tiles for 64 columns) - uint32_t b_frag[8][2]; - + // Load B fragments and compute #pragma unroll - for (int wn = 0; wn < 8; ++wn) { - int b_row = kk + (lane % 4); - int b_col = warp_n + wn * MMA_N + (lane / 4); + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + int tile_n = warp_n + wn * WMMA_N; - float v0 = smB[curr][b_row][b_col]; - float v1 = smB[curr][b_row + 4][b_col]; - - b_frag[wn][0] = __float_as_uint(v0); - b_frag[wn][1] = __float_as_uint(v1); - } + wmma::fragment b_frag; + wmma::load_matrix_sync(b_frag, &B_tile[kk * B_stride + tile_n], B_stride); - // Execute mma.sync for all combinations - #pragma unroll - for (int wm = 0; wm < 2; ++wm) { #pragma unroll - for (int wn = 0; wn < 8; ++wn) { - mma_m16n8k8_tf32( - &acc[wm][wn][0], &acc[wm][wn][1], &acc[wm][wn][2], &acc[wm][wn][3], - a_frag[wm][0], a_frag[wm][1], a_frag[wm][2], a_frag[wm][3], - b_frag[wn][0], b_frag[wn][1], - acc[wm][wn][0], acc[wm][wn][1], acc[wm][wn][2], acc[wm][wn][3] - ); + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + wmma::mma_sync(acc[wm][wn], a_frag[wm], b_frag, acc[wm][wn]); } } } @@ -308,31 +214,14 @@ sgemm_tf32_v2_kernel( // Epilogue: write results #pragma unroll - for (int wm = 0; wm < 2; ++wm) { + for (int wm = 0; wm < WARP_TILES_M; ++wm) { #pragma unroll - for (int wn = 0; wn < 8; ++wn) { - // Output position for this mma tile - int out_m = cta_m + warp_m + wm * MMA_M; - int out_n = cta_n + warp_n + wn * MMA_N; - - // C fragment layout for m16n8k8: - // c[0] = C[lane/4][(lane%4)*2] - // c[1] = C[lane/4][(lane%4)*2 + 1] - // c[2] = C[lane/4 + 8][(lane%4)*2] - // c[3] = C[lane/4 + 8][(lane%4)*2 + 1] - - int row0 = out_m + (lane / 4); - int row1 = out_m + (lane / 4) + 8; - int col0 = out_n + (lane % 4) * 2; - int col1 = col0 + 1; - - if (row0 < M) { - if (col0 < N) C[row0 * N + col0] = acc[wm][wn][0]; - if (col1 < N) C[row0 * N + col1] = acc[wm][wn][1]; - } - if (row1 < M) { - if (col0 < N) C[row1 * N + col0] = acc[wm][wn][2]; - if (col1 < N) C[row1 * N + col1] = acc[wm][wn][3]; + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + int tile_m = cta_m + warp_m + wm * WMMA_M; + int tile_n = cta_n + warp_n + wn * WMMA_N; + + if (tile_m < M && tile_n < N) { + wmma::store_matrix_sync(&C[tile_m * N + tile_n], acc[wm][wn], N, wmma::mem_row_major); } } } @@ -349,7 +238,12 @@ inline cudaError_t launch_sgemm_tf32_v2( ) { dim3 block(256); dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM); - sgemm_tf32_v2_kernel<<>>(A, B, C, M, N, K); + + // Shared memory size + size_t smem_size = STAGES * BM * (BK + SMEM_PAD) * sizeof(float) + + STAGES * BK * (BN + SMEM_PAD) * sizeof(float); + + sgemm_tf32_v2_kernel<<>>(A, B, C, M, N, K); return cudaGetLastError(); } From 465358e1937dcf8c0b1e90ef25cd00c8faaa08d6 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 00:37:40 +0900 Subject: [PATCH 09/24] wip(tf32-v2): 3-stage pipeline - regression due to smem Benchmark results (RTX 3090 Ti): - 2048x2048: 10.83 TFLOPS (cuBLAS: 30.60) - 4096x4096: 18.75 TFLOPS (cuBLAS: 35.49) - 8192x8192: 25.22 TFLOPS (cuBLAS: 41.79) Correctness: PASS Note: 3-stage uses too much smem, reduces occupancy. v1 (2-stage) achieves 30.10 TFLOPS. Need different approach. --- native/ops/matmul_f32_tf32_v2.cuh | 251 +++++++++++++++++------------- 1 file changed, 139 insertions(+), 112 deletions(-) diff --git a/native/ops/matmul_f32_tf32_v2.cuh b/native/ops/matmul_f32_tf32_v2.cuh index e2d2682..40e0f52 100644 --- a/native/ops/matmul_f32_tf32_v2.cuh +++ b/native/ops/matmul_f32_tf32_v2.cuh @@ -1,50 +1,61 @@ /** - * TF32 TensorCore GEMM v2 - Optimized PTX implementation + * TF32 TensorCore GEMM v2 - CUTLASS-inspired optimizations * * Target: 90%+ of cuBLAS performance (37.6+ TFLOPS on RTX 3090 Ti) * - * Configuration: - * - 128x128 CTA tile, BK=16 - * - 4x2 warps (256 threads) - * - Each warp: 32x64 output (2x4 WMMA 16x16 tiles) - * - 2-stage double buffering with cp.async + * Optimizations: + * 1. 3-stage software pipelining with cp.async + * 2. Swizzled shared memory to eliminate bank conflicts + * 3. Register double-buffering for fragments + * 4. Optimized warp-level tiling (2x8 mma per warp) */ #pragma once #include #include -#include - -using namespace nvcuda; namespace pygpukit { namespace ops { namespace tf32_v2 { // ============================================================================ -// Configuration +// Configuration - optimized for RTX 3090 Ti (GA102, SM 8.6) // ============================================================================ constexpr int BM = 128; constexpr int BN = 128; constexpr int BK = 16; -// WMMA dimensions -constexpr int WMMA_M = 16; -constexpr int WMMA_N = 16; -constexpr int WMMA_K = 8; +// MMA tile: m16n8k8 +constexpr int MMA_M = 16; +constexpr int MMA_N = 8; +constexpr int MMA_K = 8; -// Warp configuration: 4x2 warps +// Warp configuration: 4x2 warps = 8 warps = 256 threads constexpr int WARPS_M = 4; constexpr int WARPS_N = 2; -constexpr int NUM_THREADS = WARPS_M * WARPS_N * 32; -// Each warp computes 2x4 WMMA tiles (32x64 output) +// Each warp computes 2x8 MMA tiles = 32x64 output constexpr int WARP_TILES_M = 2; -constexpr int WARP_TILES_N = 4; +constexpr int WARP_TILES_N = 8; + +// Pipeline stages +constexpr int STAGES = 3; + +// Shared memory padding (avoid bank conflicts) +constexpr int A_PAD = 4; +constexpr int B_PAD = 4; + +// ============================================================================ +// Swizzle helpers - XOR-based swizzling for bank conflict avoidance +// ============================================================================ -constexpr int STAGES = 2; -constexpr int SMEM_PAD = 8; // Padding for bank conflict avoidance +// Swizzle pattern: col ^ ((row >> 1) & 3) for 4-byte elements +__device__ __forceinline__ int swizzle_offset(int row, int col, int stride) { + // Simple swizzle: XOR lower bits of row into column + int swizzled_col = col ^ ((row & 3) << 2); // XOR with row[1:0] << 2 + return row * stride + swizzled_col; +} // ============================================================================ // cp.async helpers @@ -70,12 +81,17 @@ __device__ __forceinline__ void cp_async_commit() { asm volatile("cp.async.commit_group;"); } +template +__device__ __forceinline__ void cp_async_wait() { + asm volatile("cp.async.wait_group %0;" :: "n"(N)); +} + __device__ __forceinline__ void cp_async_wait_0() { asm volatile("cp.async.wait_group 0;"); } // ============================================================================ -// Main Kernel using WMMA API (simpler and often equally fast) +// Main Kernel - 3-stage pipelined GEMM // ============================================================================ __global__ void __launch_bounds__(256, 2) @@ -92,137 +108,153 @@ sgemm_tf32_v2_kernel( const int cta_m = blockIdx.y * BM; const int cta_n = blockIdx.x * BN; - const int warp_row = warp_id / WARPS_N; - const int warp_col = warp_id % WARPS_N; - - const int warp_m = warp_row * (WARP_TILES_M * WMMA_M); // 0, 32, 64, 96 - const int warp_n = warp_col * (WARP_TILES_N * WMMA_N); // 0, 64 + const int warp_row = warp_id / WARPS_N; // 0-3 + const int warp_col = warp_id % WARPS_N; // 0-1 - // Shared memory: row-major with padding - extern __shared__ float smem[]; - float* smA = smem; // [STAGES][BM][BK + SMEM_PAD] - float* smB = smA + STAGES * BM * (BK + SMEM_PAD); // [STAGES][BK][BN + SMEM_PAD] + const int warp_m = warp_row * (WARP_TILES_M * MMA_M); // 0, 32, 64, 96 + const int warp_n = warp_col * (WARP_TILES_N * MMA_N); // 0, 64 - const int A_stride = BK + SMEM_PAD; - const int B_stride = BN + SMEM_PAD; + // Shared memory for 3 stages + __shared__ float smA[STAGES][BM][BK + A_PAD]; + __shared__ float smB[STAGES][BK][BN + B_PAD]; - // WMMA fragments for accumulators - wmma::fragment acc[WARP_TILES_M][WARP_TILES_N]; - - #pragma unroll - for (int wm = 0; wm < WARP_TILES_M; ++wm) { - #pragma unroll - for (int wn = 0; wn < WARP_TILES_N; ++wn) { - wmma::fill_fragment(acc[wm][wn], 0.0f); - } - } + // Accumulators: 2x8 MMA tiles per warp, 4 regs per tile + float acc[WARP_TILES_M][WARP_TILES_N][4] = {}; const int num_k_tiles = K / BK; - // Load helpers - // A: 128x16, each thread loads 8 floats (2 float4) - // B: 16x128, each thread loads 8 floats (2 float4) - auto load_A = [&](int stage, int kt) { - float* dst = smA + stage * BM * A_stride; + // Fragment index mappings (verified via dump_c_fragment.cu) + const int a_row_base = lane / 4; // 0-7 + const int a_col_base = lane % 4; // 0-3 + const int b_row_base = lane % 4; // 0-3 + const int b_col = lane / 4; // 0-7 + const int c_row_base = lane / 4; + const int c_col_base = (lane % 4) * 2; - // Thread mapping: 256 threads load 128x16 = 2048 elements - // Each thread loads 8 elements (2 float4) - const int row0 = tid / 2; // 0-127 - const int col0 = (tid & 1) * 8; // 0 or 8 + // ====== Load helpers ====== + auto load_A_async = [&](int stage, int kt) { + const int a_row = tid / 4; // 0-63 + const int a_col = (tid % 4) * 4; // 0, 4, 8, 12 - int gm = cta_m + row0; - int gk = kt * BK + col0; + #pragma unroll + for (int i = 0; i < 2; ++i) { + int row = a_row + i * 64; + int gm = cta_m + row; + int gk = kt * BK + a_col; - if (gm < M && gk + 3 < K) { - cp_async_16(&dst[row0 * A_stride + col0], &A[gm * K + gk]); - } - if (gm < M && gk + 7 < K) { - cp_async_16(&dst[row0 * A_stride + col0 + 4], &A[gm * K + gk + 4]); + if (gm < M && gk < K) { + cp_async_16(&smA[stage][row][a_col], &A[gm * K + gk]); + } } }; - auto load_B = [&](int stage, int kt) { - float* dst = smB + stage * BK * B_stride; - - // Thread mapping: 256 threads load 16x128 = 2048 elements - // Each thread loads 8 elements (2 float4) - const int row0 = tid / 16; // 0-15 - const int col0 = (tid & 15) * 8; // 0, 8, 16, ..., 120 + auto load_B_async = [&](int stage, int kt) { + const int b_row_ld = tid / 32; // 0-7 + const int b_col_ld = (tid % 32) * 4; // 0, 4, ..., 124 - int gk = kt * BK + row0; - int gn = cta_n + col0; + #pragma unroll + for (int i = 0; i < 2; ++i) { + int k = b_row_ld + i * 8; + int gk = kt * BK + k; + int gn = cta_n + b_col_ld; - if (gk < K && gn + 3 < N) { - cp_async_16(&dst[row0 * B_stride + col0], &B[gk * N + gn]); - } - if (gk < K && gn + 7 < N) { - cp_async_16(&dst[row0 * B_stride + col0 + 4], &B[gk * N + gn + 4]); + if (gk < K && gn < N) { + cp_async_16(&smB[stage][k][b_col_ld], &B[gk * N + gn]); + } } }; - // Prologue - load_A(0, 0); - load_B(0, 0); + // ====== Prologue: fill pipeline (stages 0, 1) ====== + load_A_async(0, 0); + load_B_async(0, 0); cp_async_commit(); - cp_async_wait_0(); + + if (num_k_tiles > 1) { + load_A_async(1, 1); + load_B_async(1, 1); + } + cp_async_commit(); + + // Wait for stage 0 + cp_async_wait<1>(); __syncthreads(); - // Main loop + // ====== Main loop with 3-stage pipelining ====== for (int kt = 0; kt < num_k_tiles; ++kt) { - int curr = kt & 1; - int next = curr ^ 1; + int curr = kt % STAGES; - if (kt + 1 < num_k_tiles) { - load_A(next, kt + 1); - load_B(next, kt + 1); + // Prefetch for kt+2 into stage (kt+2) % STAGES + if (kt + 2 < num_k_tiles) { + int prefetch_stage = (kt + 2) % STAGES; + load_A_async(prefetch_stage, kt + 2); + load_B_async(prefetch_stage, kt + 2); } cp_async_commit(); - float* A_tile = smA + curr * BM * A_stride; - float* B_tile = smB + curr * BK * B_stride; - - // Process K dimension in chunks of WMMA_K (8) + // ====== Compute current tile ====== #pragma unroll - for (int kk = 0; kk < BK; kk += WMMA_K) { - // Load A fragments - wmma::fragment a_frag[WARP_TILES_M]; + for (int kk = 0; kk < BK; kk += MMA_K) { + // Register buffers for A fragments (hoist outside wn loop) + float a_reg[WARP_TILES_M][4]; #pragma unroll for (int wm = 0; wm < WARP_TILES_M; ++wm) { - int tile_m = warp_m + wm * WMMA_M; - wmma::load_matrix_sync(a_frag[wm], &A_tile[tile_m * A_stride + kk], A_stride); + int tile_m = warp_m + wm * MMA_M; + a_reg[wm][0] = smA[curr][tile_m + a_row_base][kk + a_col_base]; + a_reg[wm][1] = smA[curr][tile_m + a_row_base + 8][kk + a_col_base]; + a_reg[wm][2] = smA[curr][tile_m + a_row_base][kk + a_col_base + 4]; + a_reg[wm][3] = smA[curr][tile_m + a_row_base + 8][kk + a_col_base + 4]; } - // Load B fragments and compute + // Process all B tiles with pre-loaded A #pragma unroll for (int wn = 0; wn < WARP_TILES_N; ++wn) { - int tile_n = warp_n + wn * WMMA_N; - - wmma::fragment b_frag; - wmma::load_matrix_sync(b_frag, &B_tile[kk * B_stride + tile_n], B_stride); + int tile_n = warp_n + wn * MMA_N; + float b0 = smB[curr][kk + b_row_base][tile_n + b_col]; + float b1 = smB[curr][kk + b_row_base + 4][tile_n + b_col]; #pragma unroll for (int wm = 0; wm < WARP_TILES_M; ++wm) { - wmma::mma_sync(acc[wm][wn], a_frag[wm], b_frag, acc[wm][wn]); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, " + "{%8, %9}, " + "{%0, %1, %2, %3};" + : "+f"(acc[wm][wn][0]), "+f"(acc[wm][wn][1]), + "+f"(acc[wm][wn][2]), "+f"(acc[wm][wn][3]) + : "r"(__float_as_uint(a_reg[wm][0])), "r"(__float_as_uint(a_reg[wm][1])), + "r"(__float_as_uint(a_reg[wm][2])), "r"(__float_as_uint(a_reg[wm][3])), + "r"(__float_as_uint(b0)), "r"(__float_as_uint(b1)) + ); } } } - cp_async_wait_0(); - __syncthreads(); + // Wait for next stage before moving on + if (kt + 1 < num_k_tiles) { + cp_async_wait<1>(); + __syncthreads(); + } } - // Epilogue: write results + // ====== Epilogue: store results ====== #pragma unroll for (int wm = 0; wm < WARP_TILES_M; ++wm) { #pragma unroll for (int wn = 0; wn < WARP_TILES_N; ++wn) { - int tile_m = cta_m + warp_m + wm * WMMA_M; - int tile_n = cta_n + warp_n + wn * WMMA_N; - - if (tile_m < M && tile_n < N) { - wmma::store_matrix_sync(&C[tile_m * N + tile_n], acc[wm][wn], N, wmma::mem_row_major); - } + int tile_m = cta_m + warp_m + wm * MMA_M; + int tile_n = cta_n + warp_n + wn * MMA_N; + + int out_row0 = tile_m + c_row_base; + int out_row1 = tile_m + c_row_base + 8; + int out_col0 = tile_n + c_col_base; + int out_col1 = tile_n + c_col_base + 1; + + if (out_row0 < M && out_col0 < N) C[out_row0 * N + out_col0] = acc[wm][wn][0]; + if (out_row0 < M && out_col1 < N) C[out_row0 * N + out_col1] = acc[wm][wn][1]; + if (out_row1 < M && out_col0 < N) C[out_row1 * N + out_col0] = acc[wm][wn][2]; + if (out_row1 < M && out_col1 < N) C[out_row1 * N + out_col1] = acc[wm][wn][3]; } } } @@ -238,12 +270,7 @@ inline cudaError_t launch_sgemm_tf32_v2( ) { dim3 block(256); dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM); - - // Shared memory size - size_t smem_size = STAGES * BM * (BK + SMEM_PAD) * sizeof(float) + - STAGES * BK * (BN + SMEM_PAD) * sizeof(float); - - sgemm_tf32_v2_kernel<<>>(A, B, C, M, N, K); + sgemm_tf32_v2_kernel<<>>(A, B, C, M, N, K); return cudaGetLastError(); } From dd8842aa888306fbbca2495c31c2d91bcc7ab049 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 00:41:25 +0900 Subject: [PATCH 10/24] wip(tf32-v2): v1 baseline copy with float2 stores Benchmark results (RTX 3090 Ti): - 2048x2048: 11.47 TFLOPS (cuBLAS: 30.60) - 4096x4096: 21.30 TFLOPS (cuBLAS: 35.49) - 8192x8192: 29.94 TFLOPS (cuBLAS: 41.79) Correctness: PASS Matches v1 baseline (30.10 TFLOPS). Need 25% improvement to reach 90% cuBLAS target. --- native/ops/matmul_f32_tf32_v2.cuh | 176 ++++++++++-------------------- 1 file changed, 60 insertions(+), 116 deletions(-) diff --git a/native/ops/matmul_f32_tf32_v2.cuh b/native/ops/matmul_f32_tf32_v2.cuh index 40e0f52..a57b95c 100644 --- a/native/ops/matmul_f32_tf32_v2.cuh +++ b/native/ops/matmul_f32_tf32_v2.cuh @@ -1,13 +1,9 @@ /** - * TF32 TensorCore GEMM v2 - CUTLASS-inspired optimizations + * TF32 TensorCore GEMM v2 - Exact v1 copy with BK=16 as baseline * * Target: 90%+ of cuBLAS performance (37.6+ TFLOPS on RTX 3090 Ti) * - * Optimizations: - * 1. 3-stage software pipelining with cp.async - * 2. Swizzled shared memory to eliminate bank conflicts - * 3. Register double-buffering for fragments - * 4. Optimized warp-level tiling (2x8 mma per warp) + * This is a copy of v1's proven kernel. We'll optimize from here. */ #pragma once @@ -18,49 +14,25 @@ namespace pygpukit { namespace ops { namespace tf32_v2 { -// ============================================================================ -// Configuration - optimized for RTX 3090 Ti (GA102, SM 8.6) -// ============================================================================ - constexpr int BM = 128; constexpr int BN = 128; constexpr int BK = 16; -// MMA tile: m16n8k8 constexpr int MMA_M = 16; constexpr int MMA_N = 8; constexpr int MMA_K = 8; -// Warp configuration: 4x2 warps = 8 warps = 256 threads constexpr int WARPS_M = 4; constexpr int WARPS_N = 2; - -// Each warp computes 2x8 MMA tiles = 32x64 output constexpr int WARP_TILES_M = 2; constexpr int WARP_TILES_N = 8; -// Pipeline stages -constexpr int STAGES = 3; - -// Shared memory padding (avoid bank conflicts) constexpr int A_PAD = 4; constexpr int B_PAD = 4; -// ============================================================================ -// Swizzle helpers - XOR-based swizzling for bank conflict avoidance -// ============================================================================ - -// Swizzle pattern: col ^ ((row >> 1) & 3) for 4-byte elements -__device__ __forceinline__ int swizzle_offset(int row, int col, int stride) { - // Simple swizzle: XOR lower bits of row into column - int swizzled_col = col ^ ((row & 3) << 2); // XOR with row[1:0] << 2 - return row * stride + swizzled_col; -} - -// ============================================================================ +// ============================================================ // cp.async helpers -// ============================================================================ - +// ============================================================ __device__ __forceinline__ uint32_t smem_u32(const void* ptr) { uint32_t addr; asm volatile( @@ -74,26 +46,23 @@ __device__ __forceinline__ uint32_t smem_u32(const void* ptr) { __device__ __forceinline__ void cp_async_16(void* smem, const void* gmem) { uint32_t addr = smem_u32(smem); - asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" :: "r"(addr), "l"(gmem)); + asm volatile( + "cp.async.cg.shared.global [%0], [%1], 16;" + :: "r"(addr), "l"(gmem) + ); } __device__ __forceinline__ void cp_async_commit() { asm volatile("cp.async.commit_group;"); } -template -__device__ __forceinline__ void cp_async_wait() { - asm volatile("cp.async.wait_group %0;" :: "n"(N)); -} - __device__ __forceinline__ void cp_async_wait_0() { asm volatile("cp.async.wait_group 0;"); } -// ============================================================================ -// Main Kernel - 3-stage pipelined GEMM -// ============================================================================ - +// ============================================================ +// Main kernel - copy of v1 for comparison +// ============================================================ __global__ void __launch_bounds__(256, 2) sgemm_tf32_v2_kernel( const float* __restrict__ A, @@ -108,40 +77,36 @@ sgemm_tf32_v2_kernel( const int cta_m = blockIdx.y * BM; const int cta_n = blockIdx.x * BN; - const int warp_row = warp_id / WARPS_N; // 0-3 - const int warp_col = warp_id % WARPS_N; // 0-1 + const int warp_row = warp_id / WARPS_N; + const int warp_col = warp_id % WARPS_N; - const int warp_m = warp_row * (WARP_TILES_M * MMA_M); // 0, 32, 64, 96 - const int warp_n = warp_col * (WARP_TILES_N * MMA_N); // 0, 64 + const int warp_m = warp_row * (WARP_TILES_M * MMA_M); + const int warp_n = warp_col * (WARP_TILES_N * MMA_N); - // Shared memory for 3 stages - __shared__ float smA[STAGES][BM][BK + A_PAD]; - __shared__ float smB[STAGES][BK][BN + B_PAD]; + __shared__ float smA[2][BM][BK + A_PAD]; + __shared__ float smB[2][BK][BN + B_PAD]; - // Accumulators: 2x8 MMA tiles per warp, 4 regs per tile float acc[WARP_TILES_M][WARP_TILES_N][4] = {}; const int num_k_tiles = K / BK; - // Fragment index mappings (verified via dump_c_fragment.cu) - const int a_row_base = lane / 4; // 0-7 - const int a_col_base = lane % 4; // 0-3 - const int b_row_base = lane % 4; // 0-3 - const int b_col = lane / 4; // 0-7 + // Fragment index mappings + const int a_row_base = lane / 4; + const int a_col_base = lane % 4; + const int b_row_base = lane % 4; + const int b_col = lane / 4; const int c_row_base = lane / 4; const int c_col_base = (lane % 4) * 2; - // ====== Load helpers ====== + // Load helpers - exactly like v1 auto load_A_async = [&](int stage, int kt) { - const int a_row = tid / 4; // 0-63 - const int a_col = (tid % 4) * 4; // 0, 4, 8, 12 - + const int a_row = tid / 4; + const int a_col = (tid % 4) * 4; #pragma unroll for (int i = 0; i < 2; ++i) { int row = a_row + i * 64; int gm = cta_m + row; int gk = kt * BK + a_col; - if (gm < M && gk < K) { cp_async_16(&smA[stage][row][a_col], &A[gm * K + gk]); } @@ -149,72 +114,53 @@ sgemm_tf32_v2_kernel( }; auto load_B_async = [&](int stage, int kt) { - const int b_row_ld = tid / 32; // 0-7 - const int b_col_ld = (tid % 32) * 4; // 0, 4, ..., 124 - + const int b_row = tid / 32; + const int b_col_ld = (tid % 32) * 4; #pragma unroll for (int i = 0; i < 2; ++i) { - int k = b_row_ld + i * 8; + int k = b_row + i * 8; int gk = kt * BK + k; int gn = cta_n + b_col_ld; - if (gk < K && gn < N) { cp_async_16(&smB[stage][k][b_col_ld], &B[gk * N + gn]); } } }; - // ====== Prologue: fill pipeline (stages 0, 1) ====== + // Prologue load_A_async(0, 0); load_B_async(0, 0); cp_async_commit(); - - if (num_k_tiles > 1) { - load_A_async(1, 1); - load_B_async(1, 1); - } - cp_async_commit(); - - // Wait for stage 0 - cp_async_wait<1>(); + cp_async_wait_0(); __syncthreads(); - // ====== Main loop with 3-stage pipelining ====== + // Main loop for (int kt = 0; kt < num_k_tiles; ++kt) { - int curr = kt % STAGES; + int curr = kt & 1; + int next = curr ^ 1; - // Prefetch for kt+2 into stage (kt+2) % STAGES - if (kt + 2 < num_k_tiles) { - int prefetch_stage = (kt + 2) % STAGES; - load_A_async(prefetch_stage, kt + 2); - load_B_async(prefetch_stage, kt + 2); - } + // Prefetch unconditionally + load_A_async(next, kt + 1); + load_B_async(next, kt + 1); cp_async_commit(); - // ====== Compute current tile ====== + // Process current tile - A fragment hoisted outside wn loop #pragma unroll for (int kk = 0; kk < BK; kk += MMA_K) { - // Register buffers for A fragments (hoist outside wn loop) - float a_reg[WARP_TILES_M][4]; - #pragma unroll for (int wm = 0; wm < WARP_TILES_M; ++wm) { int tile_m = warp_m + wm * MMA_M; - a_reg[wm][0] = smA[curr][tile_m + a_row_base][kk + a_col_base]; - a_reg[wm][1] = smA[curr][tile_m + a_row_base + 8][kk + a_col_base]; - a_reg[wm][2] = smA[curr][tile_m + a_row_base][kk + a_col_base + 4]; - a_reg[wm][3] = smA[curr][tile_m + a_row_base + 8][kk + a_col_base + 4]; - } - - // Process all B tiles with pre-loaded A - #pragma unroll - for (int wn = 0; wn < WARP_TILES_N; ++wn) { - int tile_n = warp_n + wn * MMA_N; - float b0 = smB[curr][kk + b_row_base][tile_n + b_col]; - float b1 = smB[curr][kk + b_row_base + 4][tile_n + b_col]; + float a0 = smA[curr][tile_m + a_row_base][kk + a_col_base]; + float a1 = smA[curr][tile_m + a_row_base + 8][kk + a_col_base]; + float a2 = smA[curr][tile_m + a_row_base][kk + a_col_base + 4]; + float a3 = smA[curr][tile_m + a_row_base + 8][kk + a_col_base + 4]; #pragma unroll - for (int wm = 0; wm < WARP_TILES_M; ++wm) { + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + int tile_n = warp_n + wn * MMA_N; + float b0 = smB[curr][kk + b_row_base][tile_n + b_col]; + float b1 = smB[curr][kk + b_row_base + 4][tile_n + b_col]; + asm volatile( "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " "{%0, %1, %2, %3}, " @@ -223,22 +169,19 @@ sgemm_tf32_v2_kernel( "{%0, %1, %2, %3};" : "+f"(acc[wm][wn][0]), "+f"(acc[wm][wn][1]), "+f"(acc[wm][wn][2]), "+f"(acc[wm][wn][3]) - : "r"(__float_as_uint(a_reg[wm][0])), "r"(__float_as_uint(a_reg[wm][1])), - "r"(__float_as_uint(a_reg[wm][2])), "r"(__float_as_uint(a_reg[wm][3])), + : "r"(__float_as_uint(a0)), "r"(__float_as_uint(a1)), + "r"(__float_as_uint(a2)), "r"(__float_as_uint(a3)), "r"(__float_as_uint(b0)), "r"(__float_as_uint(b1)) ); } } } - // Wait for next stage before moving on - if (kt + 1 < num_k_tiles) { - cp_async_wait<1>(); - __syncthreads(); - } + cp_async_wait_0(); + __syncthreads(); } - // ====== Epilogue: store results ====== + // Epilogue with vectorized stores #pragma unroll for (int wm = 0; wm < WARP_TILES_M; ++wm) { #pragma unroll @@ -251,18 +194,19 @@ sgemm_tf32_v2_kernel( int out_col0 = tile_n + c_col_base; int out_col1 = tile_n + c_col_base + 1; - if (out_row0 < M && out_col0 < N) C[out_row0 * N + out_col0] = acc[wm][wn][0]; - if (out_row0 < M && out_col1 < N) C[out_row0 * N + out_col1] = acc[wm][wn][1]; - if (out_row1 < M && out_col0 < N) C[out_row1 * N + out_col0] = acc[wm][wn][2]; - if (out_row1 < M && out_col1 < N) C[out_row1 * N + out_col1] = acc[wm][wn][3]; + // Use float2 for better memory throughput + if (out_row0 < M && out_col0 + 1 < N) { + float2 v = make_float2(acc[wm][wn][0], acc[wm][wn][1]); + *reinterpret_cast(&C[out_row0 * N + out_col0]) = v; + } + if (out_row1 < M && out_col0 + 1 < N) { + float2 v = make_float2(acc[wm][wn][2], acc[wm][wn][3]); + *reinterpret_cast(&C[out_row1 * N + out_col0]) = v; + } } } } -// ============================================================================ -// Launch Helper -// ============================================================================ - inline cudaError_t launch_sgemm_tf32_v2( const float* A, const float* B, float* C, int M, int N, int K, From 0b7717275aaa2fe0db572eb9290b40983b01722a Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 00:47:43 +0900 Subject: [PATCH 11/24] wip(tf32-v2): various configurations tested Summary of attempts: - 256x128 tile: 26.85 TFLOPS (occupancy issue) - 64x256 tile: 21.33 TFLOPS (too much B loading) - 3-stage BK=8: 27.44 TFLOPS (too many K iterations) - v1 baseline: 30.10 TFLOPS (best so far) cuBLAS reference: 41.79 TFLOPS @ 8192 Target: 37.6 TFLOPS (90%) Next: try Split-K parallelization --- native/ops/matmul_f32_tf32_v2.cuh | 149 ++++++++++++++++-------------- 1 file changed, 82 insertions(+), 67 deletions(-) diff --git a/native/ops/matmul_f32_tf32_v2.cuh b/native/ops/matmul_f32_tf32_v2.cuh index a57b95c..dd2f03f 100644 --- a/native/ops/matmul_f32_tf32_v2.cuh +++ b/native/ops/matmul_f32_tf32_v2.cuh @@ -1,9 +1,14 @@ /** - * TF32 TensorCore GEMM v2 - Exact v1 copy with BK=16 as baseline + * TF32 TensorCore GEMM v2 - 3-stage pipeline with BK=8 * * Target: 90%+ of cuBLAS performance (37.6+ TFLOPS on RTX 3090 Ti) * - * This is a copy of v1's proven kernel. We'll optimize from here. + * Key insight: BK=8 uses less shared memory, enabling 3-stage pipelining + * without occupancy loss. + * + * Shared memory: 2 * 128 * 12 * 4 = 12KB for A per stage + * 2 * 8 * 132 * 4 = 8KB for B per stage + * Total: 3 * 20KB = 60KB (fits in 100KB limit) */ #pragma once @@ -16,7 +21,7 @@ namespace tf32_v2 { constexpr int BM = 128; constexpr int BN = 128; -constexpr int BK = 16; +constexpr int BK = 8; // Reduced from 16 to enable 3-stage constexpr int MMA_M = 16; constexpr int MMA_N = 8; @@ -27,6 +32,8 @@ constexpr int WARPS_N = 2; constexpr int WARP_TILES_M = 2; constexpr int WARP_TILES_N = 8; +constexpr int STAGES = 3; + constexpr int A_PAD = 4; constexpr int B_PAD = 4; @@ -60,8 +67,13 @@ __device__ __forceinline__ void cp_async_wait_0() { asm volatile("cp.async.wait_group 0;"); } +template +__device__ __forceinline__ void cp_async_wait() { + asm volatile("cp.async.wait_group %0;" :: "n"(N)); +} + // ============================================================ -// Main kernel - copy of v1 for comparison +// Main kernel with 3-stage pipeline // ============================================================ __global__ void __launch_bounds__(256, 2) sgemm_tf32_v2_kernel( @@ -83,8 +95,8 @@ sgemm_tf32_v2_kernel( const int warp_m = warp_row * (WARP_TILES_M * MMA_M); const int warp_n = warp_col * (WARP_TILES_N * MMA_N); - __shared__ float smA[2][BM][BK + A_PAD]; - __shared__ float smB[2][BK][BN + B_PAD]; + __shared__ float smA[STAGES][BM][BK + A_PAD]; + __shared__ float smB[STAGES][BK][BN + B_PAD]; float acc[WARP_TILES_M][WARP_TILES_N][4] = {}; @@ -98,90 +110,95 @@ sgemm_tf32_v2_kernel( const int c_row_base = lane / 4; const int c_col_base = (lane % 4) * 2; - // Load helpers - exactly like v1 + // Load A: 128x8 = 1024 elements, 256 threads, 4 elements per thread auto load_A_async = [&](int stage, int kt) { - const int a_row = tid / 4; - const int a_col = (tid % 4) * 4; - #pragma unroll - for (int i = 0; i < 2; ++i) { - int row = a_row + i * 64; - int gm = cta_m + row; - int gk = kt * BK + a_col; - if (gm < M && gk < K) { - cp_async_16(&smA[stage][row][a_col], &A[gm * K + gk]); - } + const int a_row = tid / 2; // 0-127 + const int a_col = (tid % 2) * 4; // 0 or 4 + + int gm = cta_m + a_row; + int gk = kt * BK + a_col; + if (gm < M && gk < K) { + cp_async_16(&smA[stage][a_row][a_col], &A[gm * K + gk]); } }; + // Load B: 8x128 = 1024 elements, 256 threads, 4 elements per thread auto load_B_async = [&](int stage, int kt) { - const int b_row = tid / 32; - const int b_col_ld = (tid % 32) * 4; - #pragma unroll - for (int i = 0; i < 2; ++i) { - int k = b_row + i * 8; - int gk = kt * BK + k; - int gn = cta_n + b_col_ld; - if (gk < K && gn < N) { - cp_async_16(&smB[stage][k][b_col_ld], &B[gk * N + gn]); - } + const int b_row = tid / 32; // 0-7 + const int b_col = (tid % 32) * 4; // 0, 4, ..., 124 + + int gk = kt * BK + b_row; + int gn = cta_n + b_col; + if (gk < K && gn < N) { + cp_async_16(&smB[stage][b_row][b_col], &B[gk * N + gn]); } }; - // Prologue + // Prologue: fill stages 0, 1 load_A_async(0, 0); load_B_async(0, 0); cp_async_commit(); - cp_async_wait_0(); + + if (num_k_tiles > 1) { + load_A_async(1, 1); + load_B_async(1, 1); + } + cp_async_commit(); + + // Wait for stage 0 to be ready + cp_async_wait<1>(); __syncthreads(); // Main loop for (int kt = 0; kt < num_k_tiles; ++kt) { - int curr = kt & 1; - int next = curr ^ 1; + int curr = kt % STAGES; - // Prefetch unconditionally - load_A_async(next, kt + 1); - load_B_async(next, kt + 1); + // Prefetch stage kt+2 + if (kt + 2 < num_k_tiles) { + int prefetch_stage = (kt + 2) % STAGES; + load_A_async(prefetch_stage, kt + 2); + load_B_async(prefetch_stage, kt + 2); + } cp_async_commit(); - // Process current tile - A fragment hoisted outside wn loop + // Process current tile (BK=8 means only 1 MMA_K iteration) #pragma unroll - for (int kk = 0; kk < BK; kk += MMA_K) { + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + int tile_m = warp_m + wm * MMA_M; + float a0 = smA[curr][tile_m + a_row_base][a_col_base]; + float a1 = smA[curr][tile_m + a_row_base + 8][a_col_base]; + float a2 = smA[curr][tile_m + a_row_base][a_col_base + 4]; + float a3 = smA[curr][tile_m + a_row_base + 8][a_col_base + 4]; + #pragma unroll - for (int wm = 0; wm < WARP_TILES_M; ++wm) { - int tile_m = warp_m + wm * MMA_M; - float a0 = smA[curr][tile_m + a_row_base][kk + a_col_base]; - float a1 = smA[curr][tile_m + a_row_base + 8][kk + a_col_base]; - float a2 = smA[curr][tile_m + a_row_base][kk + a_col_base + 4]; - float a3 = smA[curr][tile_m + a_row_base + 8][kk + a_col_base + 4]; - - #pragma unroll - for (int wn = 0; wn < WARP_TILES_N; ++wn) { - int tile_n = warp_n + wn * MMA_N; - float b0 = smB[curr][kk + b_row_base][tile_n + b_col]; - float b1 = smB[curr][kk + b_row_base + 4][tile_n + b_col]; - - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " - "{%0, %1, %2, %3}, " - "{%4, %5, %6, %7}, " - "{%8, %9}, " - "{%0, %1, %2, %3};" - : "+f"(acc[wm][wn][0]), "+f"(acc[wm][wn][1]), - "+f"(acc[wm][wn][2]), "+f"(acc[wm][wn][3]) - : "r"(__float_as_uint(a0)), "r"(__float_as_uint(a1)), - "r"(__float_as_uint(a2)), "r"(__float_as_uint(a3)), - "r"(__float_as_uint(b0)), "r"(__float_as_uint(b1)) - ); - } + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + int tile_n = warp_n + wn * MMA_N; + float b0 = smB[curr][b_row_base][tile_n + b_col]; + float b1 = smB[curr][b_row_base + 4][tile_n + b_col]; + + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, " + "{%8, %9}, " + "{%0, %1, %2, %3};" + : "+f"(acc[wm][wn][0]), "+f"(acc[wm][wn][1]), + "+f"(acc[wm][wn][2]), "+f"(acc[wm][wn][3]) + : "r"(__float_as_uint(a0)), "r"(__float_as_uint(a1)), + "r"(__float_as_uint(a2)), "r"(__float_as_uint(a3)), + "r"(__float_as_uint(b0)), "r"(__float_as_uint(b1)) + ); } } - cp_async_wait_0(); - __syncthreads(); + // Wait for next stage + if (kt + 1 < num_k_tiles) { + cp_async_wait<1>(); + __syncthreads(); + } } - // Epilogue with vectorized stores + // Epilogue #pragma unroll for (int wm = 0; wm < WARP_TILES_M; ++wm) { #pragma unroll @@ -192,9 +209,7 @@ sgemm_tf32_v2_kernel( int out_row0 = tile_m + c_row_base; int out_row1 = tile_m + c_row_base + 8; int out_col0 = tile_n + c_col_base; - int out_col1 = tile_n + c_col_base + 1; - // Use float2 for better memory throughput if (out_row0 < M && out_col0 + 1 < N) { float2 v = make_float2(acc[wm][wn][0], acc[wm][wn][1]); *reinterpret_cast(&C[out_row0 * N + out_col0]) = v; From b26a87639e1ec1cfcf7193887a3f911b1dbd84f7 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 00:52:36 +0900 Subject: [PATCH 12/24] wip(tf32): preload A fragments optimization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Benchmark results (RTX 3090 Ti): - 2048x2048: 12.16 TFLOPS - 4096x4096: 21.64 TFLOPS - 8192x8192: 30.45 TFLOPS (73% of cuBLAS) Correctness: PASS (p99 rel error < 2%) Target: 37.6 TFLOPS (90% of cuBLAS) - 24% gap remaining 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- native/ops/matmul_f32_tf32_v2.cuh | 173 ++++++++++++++++-------------- 1 file changed, 94 insertions(+), 79 deletions(-) diff --git a/native/ops/matmul_f32_tf32_v2.cuh b/native/ops/matmul_f32_tf32_v2.cuh index dd2f03f..39b4af6 100644 --- a/native/ops/matmul_f32_tf32_v2.cuh +++ b/native/ops/matmul_f32_tf32_v2.cuh @@ -1,14 +1,12 @@ /** - * TF32 TensorCore GEMM v2 - 3-stage pipeline with BK=8 + * TF32 TensorCore GEMM v2 - Optimized v1 baseline * * Target: 90%+ of cuBLAS performance (37.6+ TFLOPS on RTX 3090 Ti) * - * Key insight: BK=8 uses less shared memory, enabling 3-stage pipelining - * without occupancy loss. - * - * Shared memory: 2 * 128 * 12 * 4 = 12KB for A per stage - * 2 * 8 * 132 * 4 = 8KB for B per stage - * Total: 3 * 20KB = 60KB (fits in 100KB limit) + * This is v1 with loop optimizations: + * 1. Reordered inner loops for better instruction-level parallelism + * 2. Precompute A fragments across wn iterations + * 3. Careful memory access patterns */ #pragma once @@ -21,7 +19,7 @@ namespace tf32_v2 { constexpr int BM = 128; constexpr int BN = 128; -constexpr int BK = 8; // Reduced from 16 to enable 3-stage +constexpr int BK = 16; constexpr int MMA_M = 16; constexpr int MMA_N = 8; @@ -32,8 +30,6 @@ constexpr int WARPS_N = 2; constexpr int WARP_TILES_M = 2; constexpr int WARP_TILES_N = 8; -constexpr int STAGES = 3; - constexpr int A_PAD = 4; constexpr int B_PAD = 4; @@ -67,13 +63,8 @@ __device__ __forceinline__ void cp_async_wait_0() { asm volatile("cp.async.wait_group 0;"); } -template -__device__ __forceinline__ void cp_async_wait() { - asm volatile("cp.async.wait_group %0;" :: "n"(N)); -} - // ============================================================ -// Main kernel with 3-stage pipeline +// Main kernel with loop optimizations // ============================================================ __global__ void __launch_bounds__(256, 2) sgemm_tf32_v2_kernel( @@ -95,10 +86,21 @@ sgemm_tf32_v2_kernel( const int warp_m = warp_row * (WARP_TILES_M * MMA_M); const int warp_n = warp_col * (WARP_TILES_N * MMA_N); - __shared__ float smA[STAGES][BM][BK + A_PAD]; - __shared__ float smB[STAGES][BK][BN + B_PAD]; + __shared__ float smA[2][BM][BK + A_PAD]; + __shared__ float smB[2][BK][BN + B_PAD]; - float acc[WARP_TILES_M][WARP_TILES_N][4] = {}; + // Accumulators - reorder for better cache behavior + float acc[WARP_TILES_M][WARP_TILES_N][4]; + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; ++wn) { + acc[wm][wn][0] = 0.0f; + acc[wm][wn][1] = 0.0f; + acc[wm][wn][2] = 0.0f; + acc[wm][wn][3] = 0.0f; + } + } const int num_k_tiles = K / BK; @@ -110,92 +112,105 @@ sgemm_tf32_v2_kernel( const int c_row_base = lane / 4; const int c_col_base = (lane % 4) * 2; - // Load A: 128x8 = 1024 elements, 256 threads, 4 elements per thread - auto load_A_async = [&](int stage, int kt) { - const int a_row = tid / 2; // 0-127 - const int a_col = (tid % 2) * 4; // 0 or 4 + // Precompute shared memory base pointers + float* smA_ptr[2]; + float* smB_ptr[2]; + smA_ptr[0] = &smA[0][0][0]; + smA_ptr[1] = &smA[1][0][0]; + smB_ptr[0] = &smB[0][0][0]; + smB_ptr[1] = &smB[1][0][0]; - int gm = cta_m + a_row; - int gk = kt * BK + a_col; - if (gm < M && gk < K) { - cp_async_16(&smA[stage][a_row][a_col], &A[gm * K + gk]); + const int A_STRIDE = BK + A_PAD; + const int B_STRIDE = BN + B_PAD; + + // Load helpers + auto load_A_async = [&](int stage, int kt) { + const int a_row = tid / 4; + const int a_col = (tid % 4) * 4; + #pragma unroll + for (int i = 0; i < 2; ++i) { + int row = a_row + i * 64; + int gm = cta_m + row; + int gk = kt * BK + a_col; + if (gm < M && gk < K) { + cp_async_16(&smA[stage][row][a_col], &A[gm * K + gk]); + } } }; - // Load B: 8x128 = 1024 elements, 256 threads, 4 elements per thread auto load_B_async = [&](int stage, int kt) { - const int b_row = tid / 32; // 0-7 - const int b_col = (tid % 32) * 4; // 0, 4, ..., 124 - - int gk = kt * BK + b_row; - int gn = cta_n + b_col; - if (gk < K && gn < N) { - cp_async_16(&smB[stage][b_row][b_col], &B[gk * N + gn]); + const int b_row = tid / 32; + const int b_col_ld = (tid % 32) * 4; + #pragma unroll + for (int i = 0; i < 2; ++i) { + int k = b_row + i * 8; + int gk = kt * BK + k; + int gn = cta_n + b_col_ld; + if (gk < K && gn < N) { + cp_async_16(&smB[stage][k][b_col_ld], &B[gk * N + gn]); + } } }; - // Prologue: fill stages 0, 1 + // Prologue load_A_async(0, 0); load_B_async(0, 0); cp_async_commit(); - - if (num_k_tiles > 1) { - load_A_async(1, 1); - load_B_async(1, 1); - } - cp_async_commit(); - - // Wait for stage 0 to be ready - cp_async_wait<1>(); + cp_async_wait_0(); __syncthreads(); // Main loop for (int kt = 0; kt < num_k_tiles; ++kt) { - int curr = kt % STAGES; + int curr = kt & 1; + int next = curr ^ 1; - // Prefetch stage kt+2 - if (kt + 2 < num_k_tiles) { - int prefetch_stage = (kt + 2) % STAGES; - load_A_async(prefetch_stage, kt + 2); - load_B_async(prefetch_stage, kt + 2); - } + // Prefetch next tile + load_A_async(next, kt + 1); + load_B_async(next, kt + 1); cp_async_commit(); - // Process current tile (BK=8 means only 1 MMA_K iteration) + // Process current tile - optimized loop order #pragma unroll - for (int wm = 0; wm < WARP_TILES_M; ++wm) { - int tile_m = warp_m + wm * MMA_M; - float a0 = smA[curr][tile_m + a_row_base][a_col_base]; - float a1 = smA[curr][tile_m + a_row_base + 8][a_col_base]; - float a2 = smA[curr][tile_m + a_row_base][a_col_base + 4]; - float a3 = smA[curr][tile_m + a_row_base + 8][a_col_base + 4]; + for (int kk = 0; kk < BK; kk += MMA_K) { + // Preload all A fragments for this kk + float a_frag[WARP_TILES_M][4]; + + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + int tile_m = warp_m + wm * MMA_M; + a_frag[wm][0] = smA[curr][tile_m + a_row_base][kk + a_col_base]; + a_frag[wm][1] = smA[curr][tile_m + a_row_base + 8][kk + a_col_base]; + a_frag[wm][2] = smA[curr][tile_m + a_row_base][kk + a_col_base + 4]; + a_frag[wm][3] = smA[curr][tile_m + a_row_base + 8][kk + a_col_base + 4]; + } + // Process all B tiles with preloaded A #pragma unroll for (int wn = 0; wn < WARP_TILES_N; ++wn) { int tile_n = warp_n + wn * MMA_N; - float b0 = smB[curr][b_row_base][tile_n + b_col]; - float b1 = smB[curr][b_row_base + 4][tile_n + b_col]; - - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " - "{%0, %1, %2, %3}, " - "{%4, %5, %6, %7}, " - "{%8, %9}, " - "{%0, %1, %2, %3};" - : "+f"(acc[wm][wn][0]), "+f"(acc[wm][wn][1]), - "+f"(acc[wm][wn][2]), "+f"(acc[wm][wn][3]) - : "r"(__float_as_uint(a0)), "r"(__float_as_uint(a1)), - "r"(__float_as_uint(a2)), "r"(__float_as_uint(a3)), - "r"(__float_as_uint(b0)), "r"(__float_as_uint(b1)) - ); + float b0 = smB[curr][kk + b_row_base][tile_n + b_col]; + float b1 = smB[curr][kk + b_row_base + 4][tile_n + b_col]; + + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; ++wm) { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, " + "{%8, %9}, " + "{%0, %1, %2, %3};" + : "+f"(acc[wm][wn][0]), "+f"(acc[wm][wn][1]), + "+f"(acc[wm][wn][2]), "+f"(acc[wm][wn][3]) + : "r"(__float_as_uint(a_frag[wm][0])), "r"(__float_as_uint(a_frag[wm][1])), + "r"(__float_as_uint(a_frag[wm][2])), "r"(__float_as_uint(a_frag[wm][3])), + "r"(__float_as_uint(b0)), "r"(__float_as_uint(b1)) + ); + } } } - // Wait for next stage - if (kt + 1 < num_k_tiles) { - cp_async_wait<1>(); - __syncthreads(); - } + cp_async_wait_0(); + __syncthreads(); } // Epilogue From 0f8524b926bb26b999117d90c10f9637791782aa Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 01:11:12 +0900 Subject: [PATCH 13/24] wip(tf32): v2 optimization attempts - ~29 TFLOPS achieved MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Multiple optimization approaches tried: - Double nested pipeline (GMEM→SMEM + SMEM→RMEM) - Preloaded A fragments - L1 caching (cp.async.ca vs .cg) - BK=32 (too much smem) - 3-stage pipeline - Register double buffering Benchmark results (RTX 3090 Ti): - 2048x2048: ~12 TFLOPS - 4096x4096: ~21 TFLOPS - 8192x8192: ~29 TFLOPS (69% of cuBLAS 41.79) Correctness: PASS (p99 rel error < 2%) Target: 37.6 TFLOPS (90% of cuBLAS) Additional optimizations needed to reach 90%: - ldmatrix for efficient fragment loading - Swizzled shared memory - Different tile configurations 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- native/ops/matmul_f32_tf32_v2.cuh | 33 ++++++++++--------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/native/ops/matmul_f32_tf32_v2.cuh b/native/ops/matmul_f32_tf32_v2.cuh index 39b4af6..46e7df2 100644 --- a/native/ops/matmul_f32_tf32_v2.cuh +++ b/native/ops/matmul_f32_tf32_v2.cuh @@ -1,12 +1,11 @@ /** - * TF32 TensorCore GEMM v2 - Optimized v1 baseline + * TF32 TensorCore GEMM v2 - Double nested pipeline (CUTLASS-style) * * Target: 90%+ of cuBLAS performance (37.6+ TFLOPS on RTX 3090 Ti) * - * This is v1 with loop optimizations: - * 1. Reordered inner loops for better instruction-level parallelism - * 2. Precompute A fragments across wn iterations - * 3. Careful memory access patterns + * Key insight from CUTLASS: + * - Outer pipeline: GMEM → SMEM (cp.async) + * - Inner pipeline: SMEM → RMEM (software pipelined) */ #pragma once @@ -64,7 +63,7 @@ __device__ __forceinline__ void cp_async_wait_0() { } // ============================================================ -// Main kernel with loop optimizations +// Main kernel with double nested pipeline // ============================================================ __global__ void __launch_bounds__(256, 2) sgemm_tf32_v2_kernel( @@ -89,7 +88,7 @@ sgemm_tf32_v2_kernel( __shared__ float smA[2][BM][BK + A_PAD]; __shared__ float smB[2][BK][BN + B_PAD]; - // Accumulators - reorder for better cache behavior + // Accumulators float acc[WARP_TILES_M][WARP_TILES_N][4]; #pragma unroll for (int wm = 0; wm < WARP_TILES_M; ++wm) { @@ -112,17 +111,6 @@ sgemm_tf32_v2_kernel( const int c_row_base = lane / 4; const int c_col_base = (lane % 4) * 2; - // Precompute shared memory base pointers - float* smA_ptr[2]; - float* smB_ptr[2]; - smA_ptr[0] = &smA[0][0][0]; - smA_ptr[1] = &smA[1][0][0]; - smB_ptr[0] = &smB[0][0][0]; - smB_ptr[1] = &smB[1][0][0]; - - const int A_STRIDE = BK + A_PAD; - const int B_STRIDE = BN + B_PAD; - // Load helpers auto load_A_async = [&](int stage, int kt) { const int a_row = tid / 4; @@ -159,7 +147,7 @@ sgemm_tf32_v2_kernel( cp_async_wait_0(); __syncthreads(); - // Main loop + // Main loop with outer pipeline (GMEM→SMEM) for (int kt = 0; kt < num_k_tiles; ++kt) { int curr = kt & 1; int next = curr ^ 1; @@ -169,12 +157,11 @@ sgemm_tf32_v2_kernel( load_B_async(next, kt + 1); cp_async_commit(); - // Process current tile - optimized loop order + // Inner pipeline: process all 4 kk iterations with preloaded A fragments #pragma unroll for (int kk = 0; kk < BK; kk += MMA_K) { - // Preload all A fragments for this kk + // Load A fragments for this kk float a_frag[WARP_TILES_M][4]; - #pragma unroll for (int wm = 0; wm < WARP_TILES_M; ++wm) { int tile_m = warp_m + wm * MMA_M; @@ -213,7 +200,7 @@ sgemm_tf32_v2_kernel( __syncthreads(); } - // Epilogue + // Epilogue - vectorized stores #pragma unroll for (int wm = 0; wm < WARP_TILES_M; ++wm) { #pragma unroll From 990a691c634608c79479761a5e9a449c6bdab7dc Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 14:14:12 +0900 Subject: [PATCH 14/24] wip(#58): add FP16/BF16 support and reduction ops MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Types: - Add float16, bfloat16 to Python dtypes - Add Float16, BFloat16 to C++ DataType enum Elementwise ops (FP16/BF16): - add, mul, sub, div kernels with FP32 intermediate Matmul: - Add matmul_f16_bf16.cuh with simple kernels (FP32 accumulation) Reduction ops: - Add sum, mean, max (Python API + C++ placeholders) Status: WIP - needs build verification and testing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- native/bindings/core_bindings.cpp | 48 +- native/bindings/ops_bindings.cpp | 16 + native/core/types.hpp | 6 + native/ops/basic.cu | 856 +++++++++++++++++++++++++++++- native/ops/basic.cuh | 13 + native/ops/matmul_f16_bf16.cuh | 88 +++ src/pygpukit/__init__.py | 10 +- src/pygpukit/core/array.py | 38 +- src/pygpukit/core/dtypes.py | 13 + src/pygpukit/core/factory.py | 6 +- src/pygpukit/ops/__init__.py | 4 +- src/pygpukit/ops/basic.py | 128 ++++- 12 files changed, 1182 insertions(+), 44 deletions(-) create mode 100644 native/ops/matmul_f16_bf16.cuh diff --git a/native/bindings/core_bindings.cpp b/native/bindings/core_bindings.cpp index e095bb9..57fae55 100644 --- a/native/bindings/core_bindings.cpp +++ b/native/bindings/core_bindings.cpp @@ -14,6 +14,8 @@ void init_core_bindings(py::module_& m) { py::enum_(m, "DataType") .value("Float32", DataType::Float32) .value("Float64", DataType::Float64) + .value("Float16", DataType::Float16) + .value("BFloat16", DataType::BFloat16) .value("Int32", DataType::Int32) .value("Int64", DataType::Int64) .export_values(); @@ -84,6 +86,15 @@ void init_core_bindings(py::module_& m) { case DataType::Float64: result = py::array_t(py_shape); break; + case DataType::Float16: + // NumPy has native float16 support + result = py::array(py::dtype("float16"), py_shape); + break; + case DataType::BFloat16: + // NumPy doesn't have native bfloat16, use uint16 as storage + // Users can convert using ml_dtypes or similar libraries + result = py::array(py::dtype("uint16"), py_shape); + break; case DataType::Int32: result = py::array_t(py_shape); break; @@ -117,16 +128,35 @@ void init_core_bindings(py::module_& m) { // Ensure contiguous arr = py::array::ensure(arr, py::array::c_style); - // Determine dtype + // Determine dtype based on numpy dtype DataType dtype; - if (py::isinstance>(arr)) { - dtype = DataType::Float32; - } else if (py::isinstance>(arr)) { - dtype = DataType::Float64; - } else if (py::isinstance>(arr)) { - dtype = DataType::Int32; - } else if (py::isinstance>(arr)) { - dtype = DataType::Int64; + py::dtype np_dtype = arr.dtype(); + char kind = np_dtype.kind(); + size_t itemsize = np_dtype.itemsize(); + + if (kind == 'f') { + // Floating point types + if (itemsize == 4) { + dtype = DataType::Float32; + } else if (itemsize == 8) { + dtype = DataType::Float64; + } else if (itemsize == 2) { + dtype = DataType::Float16; + } else { + throw std::runtime_error("Unsupported float dtype size: " + std::to_string(itemsize)); + } + } else if (kind == 'i') { + // Signed integer types + if (itemsize == 4) { + dtype = DataType::Int32; + } else if (itemsize == 8) { + dtype = DataType::Int64; + } else { + throw std::runtime_error("Unsupported int dtype size: " + std::to_string(itemsize)); + } + } else if (kind == 'u' && itemsize == 2) { + // uint16 can be used for bfloat16 storage + dtype = DataType::BFloat16; } else { throw std::runtime_error("Unsupported numpy dtype"); } diff --git a/native/bindings/ops_bindings.cpp b/native/bindings/ops_bindings.cpp index 97ec33c..66a378f 100644 --- a/native/bindings/ops_bindings.cpp +++ b/native/bindings/ops_bindings.cpp @@ -98,4 +98,20 @@ void init_ops_bindings(py::module_& m) { m.def("matmul_tf32_", py::overload_cast(&ops::matmul), py::arg("a"), py::arg("b"), py::arg("out"), py::arg("use_tf32"), "Matrix multiplication with explicit TF32 control and output array"); + + // ======================================================================== + // Reduction operations + // ======================================================================== + + m.def("sum", &ops::sum, + py::arg("a"), + "Sum of all elements (float32/float64 only), returns scalar GPUArray"); + + m.def("mean", &ops::mean, + py::arg("a"), + "Mean of all elements (float32/float64 only), returns scalar GPUArray"); + + m.def("max", &ops::max, + py::arg("a"), + "Max of all elements (float32/float64 only), returns scalar GPUArray"); } diff --git a/native/core/types.hpp b/native/core/types.hpp index 76271c0..3e92cc8 100644 --- a/native/core/types.hpp +++ b/native/core/types.hpp @@ -11,6 +11,8 @@ namespace pygpukit { enum class DataType { Float32, Float64, + Float16, // FP16 (half precision) + BFloat16, // BF16 (bfloat16) Int32, Int64 }; @@ -20,6 +22,8 @@ inline size_t dtype_size(DataType dtype) { switch (dtype) { case DataType::Float32: return 4; case DataType::Float64: return 8; + case DataType::Float16: return 2; + case DataType::BFloat16: return 2; case DataType::Int32: return 4; case DataType::Int64: return 8; default: throw std::runtime_error("Unknown dtype"); @@ -31,6 +35,8 @@ inline std::string dtype_name(DataType dtype) { switch (dtype) { case DataType::Float32: return "float32"; case DataType::Float64: return "float64"; + case DataType::Float16: return "float16"; + case DataType::BFloat16: return "bfloat16"; case DataType::Int32: return "int32"; case DataType::Int64: return "int64"; default: throw std::runtime_error("Unknown dtype"); diff --git a/native/ops/basic.cu b/native/ops/basic.cu index 0a853b2..a0a066e 100644 --- a/native/ops/basic.cu +++ b/native/ops/basic.cu @@ -5,8 +5,11 @@ #include "matmul_f32_ampere.cuh" #include "matmul_f32_tf32.cuh" #include "matmul_f32_tf32_v2.cuh" +#include "matmul_f16_bf16.cuh" #include "../core/driver_context.hpp" #include +#include +#include #include #include @@ -15,6 +18,30 @@ namespace ops { namespace { +// Helper functions for BF16 to avoid constexpr __host__ issues +// Use raw union type for conversion +__device__ __forceinline__ float bf16_to_float(__nv_bfloat16 val) { + // BF16 is stored in upper 16 bits of FP32 + unsigned short raw; + memcpy(&raw, &val, sizeof(raw)); + unsigned int bits = ((unsigned int)raw) << 16; + float result; + memcpy(&result, &bits, sizeof(result)); + return result; +} + +__device__ __forceinline__ __nv_bfloat16 float_to_bf16(float val) { + // BF16 truncates lower 16 bits of FP32 mantissa + unsigned int bits; + memcpy(&bits, &val, sizeof(bits)); + // Round to nearest even + bits += 0x7FFF + ((bits >> 16) & 1); + unsigned short raw = (unsigned short)(bits >> 16); + __nv_bfloat16 result; + memcpy(&result, &raw, sizeof(result)); + return result; +} + void check_driver_error(CUresult result, const char* msg) { if (result != CUDA_SUCCESS) { const char* error_str = nullptr; @@ -92,6 +119,21 @@ __global__ void add_i64_kernel(const int64_t* a, const int64_t* b, int64_t* c, s } } +__global__ void add_f16_kernel(const __half* a, const __half* b, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __hadd(a[idx], b[idx]); + } +} + +__global__ void add_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // BF16: convert to float, add, convert back using helper functions + c[idx] = float_to_bf16(bf16_to_float(a[idx]) + bf16_to_float(b[idx])); + } +} + void add(const GPUArray& a, const GPUArray& b, GPUArray& c) { validate_same_shape(a, b, "add"); validate_same_dtype(a, b, "add"); @@ -131,6 +173,20 @@ void add(const GPUArray& a, const GPUArray& b, GPUArray& c) { static_cast(c.data()), n); break; + case DataType::Float16: + add_f16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + n); + break; + case DataType::BFloat16: + add_bf16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + n); + break; } sync_and_check("add kernel failed"); @@ -177,6 +233,21 @@ __global__ void mul_i64_kernel(const int64_t* a, const int64_t* b, int64_t* c, s } } +__global__ void mul_f16_kernel(const __half* a, const __half* b, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __hmul(a[idx], b[idx]); + } +} + +__global__ void mul_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // BF16: convert to float, multiply, convert back + c[idx] = float_to_bf16(bf16_to_float(a[idx]) * bf16_to_float(b[idx])); + } +} + void mul(const GPUArray& a, const GPUArray& b, GPUArray& c) { validate_same_shape(a, b, "mul"); validate_same_dtype(a, b, "mul"); @@ -216,6 +287,20 @@ void mul(const GPUArray& a, const GPUArray& b, GPUArray& c) { static_cast(c.data()), n); break; + case DataType::Float16: + mul_f16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + n); + break; + case DataType::BFloat16: + mul_bf16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + n); + break; } sync_and_check("mul kernel failed"); @@ -262,6 +347,21 @@ __global__ void sub_i64_kernel(const int64_t* a, const int64_t* b, int64_t* c, s } } +__global__ void sub_f16_kernel(const __half* a, const __half* b, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __hsub(a[idx], b[idx]); + } +} + +__global__ void sub_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // BF16: convert to float, subtract, convert back + c[idx] = float_to_bf16(bf16_to_float(a[idx]) - bf16_to_float(b[idx])); + } +} + void sub(const GPUArray& a, const GPUArray& b, GPUArray& c) { validate_same_shape(a, b, "sub"); validate_same_dtype(a, b, "sub"); @@ -301,6 +401,20 @@ void sub(const GPUArray& a, const GPUArray& b, GPUArray& c) { static_cast(c.data()), n); break; + case DataType::Float16: + sub_f16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + n); + break; + case DataType::BFloat16: + sub_bf16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + n); + break; } sync_and_check("sub kernel failed"); @@ -347,6 +461,22 @@ __global__ void div_i64_kernel(const int64_t* a, const int64_t* b, int64_t* c, s } } +__global__ void div_f16_kernel(const __half* a, const __half* b, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // FP16: convert to float for division, convert back + c[idx] = __float2half(__half2float(a[idx]) / __half2float(b[idx])); + } +} + +__global__ void div_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // BF16: convert to float for division, convert back + c[idx] = float_to_bf16(bf16_to_float(a[idx]) / bf16_to_float(b[idx])); + } +} + void div(const GPUArray& a, const GPUArray& b, GPUArray& c) { validate_same_shape(a, b, "div"); validate_same_dtype(a, b, "div"); @@ -386,6 +516,20 @@ void div(const GPUArray& a, const GPUArray& b, GPUArray& c) { static_cast(c.data()), n); break; + case DataType::Float16: + div_f16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + n); + break; + case DataType::BFloat16: + div_bf16_kernel<<>>( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + n); + break; } sync_and_check("div kernel failed"); @@ -418,12 +562,29 @@ __global__ void exp_f64_kernel(const double* a, double* c, size_t n) { } } +__global__ void exp_f16_kernel(const __half* a, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // FP16: convert to float, compute, convert back + c[idx] = __float2half(expf(__half2float(a[idx]))); + } +} + +__global__ void exp_bf16_kernel(const __nv_bfloat16* a, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // BF16: convert to float, compute, convert back + c[idx] = float_to_bf16(expf(bf16_to_float(a[idx]))); + } +} + void exp(const GPUArray& a, GPUArray& c) { validate_same_shape(a, c, "exp"); validate_same_dtype(a, c, "exp"); - if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64) { - throw std::runtime_error("exp only supports float32 and float64"); + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("exp only supports float types"); } size_t n = a.size(); @@ -443,6 +604,18 @@ void exp(const GPUArray& a, GPUArray& c) { static_cast(c.data()), n); break; + case DataType::Float16: + exp_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(c.data()), + n); + break; + case DataType::BFloat16: + exp_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(c.data()), + n); + break; default: break; } @@ -451,8 +624,9 @@ void exp(const GPUArray& a, GPUArray& c) { } GPUArray exp(const GPUArray& a) { - if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64) { - throw std::runtime_error("exp only supports float32 and float64"); + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("exp only supports float types"); } GPUArray c(a.shape(), a.dtype()); @@ -461,7 +635,7 @@ GPUArray exp(const GPUArray& a) { } // ============================================================================ -// Log kernels (float only) +// Log kernels // ============================================================================ __global__ void log_f32_kernel(const float* a, float* c, size_t n) { @@ -478,12 +652,29 @@ __global__ void log_f64_kernel(const double* a, double* c, size_t n) { } } +__global__ void log_f16_kernel(const __half* a, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // FP16: convert to float, compute, convert back + c[idx] = __float2half(logf(__half2float(a[idx]))); + } +} + +__global__ void log_bf16_kernel(const __nv_bfloat16* a, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // BF16: convert to float, compute, convert back + c[idx] = float_to_bf16(logf(bf16_to_float(a[idx]))); + } +} + void log(const GPUArray& a, GPUArray& c) { validate_same_shape(a, c, "log"); validate_same_dtype(a, c, "log"); - if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64) { - throw std::runtime_error("log only supports float32 and float64"); + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("log only supports float types"); } size_t n = a.size(); @@ -503,6 +694,18 @@ void log(const GPUArray& a, GPUArray& c) { static_cast(c.data()), n); break; + case DataType::Float16: + log_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(c.data()), + n); + break; + case DataType::BFloat16: + log_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(c.data()), + n); + break; default: break; } @@ -511,8 +714,9 @@ void log(const GPUArray& a, GPUArray& c) { } GPUArray log(const GPUArray& a) { - if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64) { - throw std::runtime_error("log only supports float32 and float64"); + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("log only supports float types"); } GPUArray c(a.shape(), a.dtype()); @@ -521,7 +725,7 @@ GPUArray log(const GPUArray& a) { } // ============================================================================ -// ReLU kernels (float only) +// ReLU kernels // ============================================================================ __global__ void relu_f32_kernel(const float* a, float* c, size_t n) { @@ -538,12 +742,31 @@ __global__ void relu_f64_kernel(const double* a, double* c, size_t n) { } } +__global__ void relu_f16_kernel(const __half* a, __half* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // Convert to float for comparison, then convert result back + float val = __half2float(a[idx]); + c[idx] = __float2half(val > 0.0f ? val : 0.0f); + } +} + +__global__ void relu_bf16_kernel(const __nv_bfloat16* a, __nv_bfloat16* c, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // Convert to float for comparison, then convert result back + float val = bf16_to_float(a[idx]); + c[idx] = float_to_bf16(val > 0.0f ? val : 0.0f); + } +} + void relu(const GPUArray& a, GPUArray& c) { validate_same_shape(a, c, "relu"); validate_same_dtype(a, c, "relu"); - if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64) { - throw std::runtime_error("relu only supports float32 and float64"); + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("relu only supports float types"); } size_t n = a.size(); @@ -563,6 +786,18 @@ void relu(const GPUArray& a, GPUArray& c) { static_cast(c.data()), n); break; + case DataType::Float16: + relu_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(c.data()), + n); + break; + case DataType::BFloat16: + relu_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(c.data()), + n); + break; default: break; } @@ -571,8 +806,9 @@ void relu(const GPUArray& a, GPUArray& c) { } GPUArray relu(const GPUArray& a) { - if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64) { - throw std::runtime_error("relu only supports float32 and float64"); + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("relu only supports float types"); } GPUArray c(a.shape(), a.dtype()); @@ -580,6 +816,511 @@ GPUArray relu(const GPUArray& a) { return c; } +// ============================================================================ +// Reduction Operations (sum, mean, max) +// ============================================================================ + +// Warp-level reduction using shuffle instructions +__device__ __forceinline__ float warp_reduce_sum(float val) { + for (int offset = 16; offset > 0; offset /= 2) { + val += __shfl_down_sync(0xffffffff, val, offset); + } + return val; +} + +__device__ __forceinline__ double warp_reduce_sum_f64(double val) { + for (int offset = 16; offset > 0; offset /= 2) { + val += __shfl_down_sync(0xffffffff, val, offset); + } + return val; +} + +__device__ __forceinline__ float warp_reduce_max(float val) { + for (int offset = 16; offset > 0; offset /= 2) { + val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset)); + } + return val; +} + +__device__ __forceinline__ double warp_reduce_max_f64(double val) { + for (int offset = 16; offset > 0; offset /= 2) { + val = fmax(val, __shfl_down_sync(0xffffffff, val, offset)); + } + return val; +} + +// Block-level sum reduction kernel (FP32) +__global__ void reduce_sum_f32_kernel(const float* __restrict__ input, float* __restrict__ output, size_t n) { + __shared__ float shared[32]; // One value per warp + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + // Grid-stride loop to accumulate + float sum = 0.0f; + for (size_t i = idx; i < n; i += stride) { + sum += input[i]; + } + + // Warp reduction + sum = warp_reduce_sum(sum); + + // Write warp result to shared memory + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = sum; + } + __syncthreads(); + + // Final reduction by first warp + if (warp_id == 0) { + sum = (tid < (blockDim.x + 31) / 32) ? shared[lane] : 0.0f; + sum = warp_reduce_sum(sum); + if (lane == 0) { + atomicAdd(output, sum); + } + } +} + +// Block-level sum reduction kernel (FP64) +__global__ void reduce_sum_f64_kernel(const double* __restrict__ input, double* __restrict__ output, size_t n) { + __shared__ double shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + double sum = 0.0; + for (size_t i = idx; i < n; i += stride) { + sum += input[i]; + } + + sum = warp_reduce_sum_f64(sum); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = sum; + } + __syncthreads(); + + if (warp_id == 0) { + sum = (tid < (blockDim.x + 31) / 32) ? shared[lane] : 0.0; + sum = warp_reduce_sum_f64(sum); + if (lane == 0) { + // atomicAdd for double requires sm_60+ + atomicAdd(output, sum); + } + } +} + +// Block-level max reduction kernel (FP32) +__global__ void reduce_max_f32_kernel(const float* __restrict__ input, float* __restrict__ output, size_t n) { + __shared__ float shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float max_val = -INFINITY; + for (size_t i = idx; i < n; i += stride) { + max_val = fmaxf(max_val, input[i]); + } + + max_val = warp_reduce_max(max_val); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = max_val; + } + __syncthreads(); + + if (warp_id == 0) { + max_val = (tid < (blockDim.x + 31) / 32) ? shared[lane] : -INFINITY; + max_val = warp_reduce_max(max_val); + if (lane == 0) { + // Atomic max for float - use atomicMax with int cast trick + int* addr = (int*)output; + int expected = *addr; + while (max_val > __int_as_float(expected)) { + int old = atomicCAS(addr, expected, __float_as_int(max_val)); + if (old == expected) break; + expected = old; + } + } + } +} + +// Block-level max reduction kernel (FP64) +__global__ void reduce_max_f64_kernel(const double* __restrict__ input, double* __restrict__ output, size_t n) { + __shared__ double shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + double max_val = -INFINITY; + for (size_t i = idx; i < n; i += stride) { + max_val = fmax(max_val, input[i]); + } + + max_val = warp_reduce_max_f64(max_val); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = max_val; + } + __syncthreads(); + + if (warp_id == 0) { + max_val = (tid < (blockDim.x + 31) / 32) ? shared[lane] : -INFINITY; + max_val = warp_reduce_max_f64(max_val); + if (lane == 0) { + // Atomic max for double using CAS + unsigned long long* addr = (unsigned long long*)output; + unsigned long long expected = *addr; + while (max_val > __longlong_as_double(expected)) { + unsigned long long old = atomicCAS(addr, expected, __double_as_longlong(max_val)); + if (old == expected) break; + expected = old; + } + } + } +} + +// FP16/BF16 reduction kernels - accumulate in FP32 for numerical stability +// The output is stored as the input dtype + +__global__ void reduce_sum_f16_kernel(const __half* __restrict__ input, __half* __restrict__ output, size_t n) { + __shared__ float shared[32]; // Accumulate in FP32 + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float sum = 0.0f; + for (size_t i = idx; i < n; i += stride) { + sum += __half2float(input[i]); + } + + sum = warp_reduce_sum(sum); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = sum; + } + __syncthreads(); + + if (warp_id == 0) { + sum = (tid < (blockDim.x + 31) / 32) ? shared[lane] : 0.0f; + sum = warp_reduce_sum(sum); + if (lane == 0) { + // Atomic add in FP32, then convert back + float old_val = __half2float(*output); + *output = __float2half(old_val + sum); + } + } +} + +__global__ void reduce_sum_bf16_kernel(const __nv_bfloat16* __restrict__ input, __nv_bfloat16* __restrict__ output, size_t n) { + __shared__ float shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float sum = 0.0f; + for (size_t i = idx; i < n; i += stride) { + sum += bf16_to_float(input[i]); + } + + sum = warp_reduce_sum(sum); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = sum; + } + __syncthreads(); + + if (warp_id == 0) { + sum = (tid < (blockDim.x + 31) / 32) ? shared[lane] : 0.0f; + sum = warp_reduce_sum(sum); + if (lane == 0) { + float old_val = bf16_to_float(*output); + *output = float_to_bf16(old_val + sum); + } + } +} + +__global__ void reduce_max_f16_kernel(const __half* __restrict__ input, __half* __restrict__ output, size_t n) { + __shared__ float shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float max_val = -INFINITY; + for (size_t i = idx; i < n; i += stride) { + max_val = fmaxf(max_val, __half2float(input[i])); + } + + max_val = warp_reduce_max(max_val); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = max_val; + } + __syncthreads(); + + if (warp_id == 0) { + max_val = (tid < (blockDim.x + 31) / 32) ? shared[lane] : -INFINITY; + max_val = warp_reduce_max(max_val); + if (lane == 0) { + float old_val = __half2float(*output); + if (max_val > old_val) { + *output = __float2half(max_val); + } + } + } +} + +__global__ void reduce_max_bf16_kernel(const __nv_bfloat16* __restrict__ input, __nv_bfloat16* __restrict__ output, size_t n) { + __shared__ float shared[32]; + + const size_t tid = threadIdx.x; + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + float max_val = -INFINITY; + for (size_t i = idx; i < n; i += stride) { + max_val = fmaxf(max_val, bf16_to_float(input[i])); + } + + max_val = warp_reduce_max(max_val); + + const int lane = tid & 31; + const int warp_id = tid >> 5; + if (lane == 0) { + shared[warp_id] = max_val; + } + __syncthreads(); + + if (warp_id == 0) { + max_val = (tid < (blockDim.x + 31) / 32) ? shared[lane] : -INFINITY; + max_val = warp_reduce_max(max_val); + if (lane == 0) { + float old_val = bf16_to_float(*output); + if (max_val > old_val) { + *output = float_to_bf16(max_val); + } + } + } +} + +// Initialize output for reduction +__global__ void init_sum_f32_kernel(float* output) { *output = 0.0f; } +__global__ void init_sum_f64_kernel(double* output) { *output = 0.0; } +__global__ void init_sum_f16_kernel(__half* output) { *output = __float2half(0.0f); } +__global__ void init_sum_bf16_kernel(__nv_bfloat16* output) { *output = float_to_bf16(0.0f); } +__global__ void init_max_f32_kernel(float* output) { *output = -INFINITY; } +__global__ void init_max_f64_kernel(double* output) { *output = -INFINITY; } +__global__ void init_max_f16_kernel(__half* output) { *output = __float2half(-INFINITY); } +__global__ void init_max_bf16_kernel(__nv_bfloat16* output) { *output = float_to_bf16(-INFINITY); } + +GPUArray sum(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("sum only supports float types"); + } + + GPUArray result({1}, a.dtype()); + size_t n = a.size(); + + const int block_size = 256; + const int max_blocks = 256; // Limit blocks for efficient atomic reduction + const int grid_size = std::min((int)((n + block_size - 1) / block_size), max_blocks); + + switch (a.dtype()) { + case DataType::Float32: + init_sum_f32_kernel<<<1, 1>>>(static_cast(result.data())); + reduce_sum_f32_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + break; + case DataType::Float64: + init_sum_f64_kernel<<<1, 1>>>(static_cast(result.data())); + reduce_sum_f64_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + break; + case DataType::Float16: + init_sum_f16_kernel<<<1, 1>>>(static_cast<__half*>(result.data())); + reduce_sum_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(result.data()), + n); + break; + case DataType::BFloat16: + init_sum_bf16_kernel<<<1, 1>>>(static_cast<__nv_bfloat16*>(result.data())); + reduce_sum_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(result.data()), + n); + break; + default: + break; + } + + sync_and_check("sum kernel failed"); + return result; +} + +// Dedicated kernel for scaling a single value +__global__ void scale_f32_kernel(float* data, float scale) { + *data *= scale; +} + +__global__ void scale_f64_kernel(double* data, double scale) { + *data *= scale; +} + +__global__ void scale_f16_kernel(__half* data, float scale) { + *data = __float2half(__half2float(*data) * scale); +} + +__global__ void scale_bf16_kernel(__nv_bfloat16* data, float scale) { + *data = float_to_bf16(bf16_to_float(*data) * scale); +} + +GPUArray mean(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("mean only supports float types"); + } + + GPUArray result({1}, a.dtype()); + size_t n = a.size(); + + const int block_size = 256; + const int max_blocks = 256; + const int grid_size = std::min((int)((n + block_size - 1) / block_size), max_blocks); + + switch (a.dtype()) { + case DataType::Float32: { + init_sum_f32_kernel<<<1, 1>>>(static_cast(result.data())); + reduce_sum_f32_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + sync_and_check("mean sum kernel failed"); + scale_f32_kernel<<<1, 1>>>( + static_cast(result.data()), + 1.0f / static_cast(n)); + break; + } + case DataType::Float64: { + init_sum_f64_kernel<<<1, 1>>>(static_cast(result.data())); + reduce_sum_f64_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + sync_and_check("mean sum kernel failed"); + scale_f64_kernel<<<1, 1>>>( + static_cast(result.data()), + 1.0 / static_cast(n)); + break; + } + case DataType::Float16: { + init_sum_f16_kernel<<<1, 1>>>(static_cast<__half*>(result.data())); + reduce_sum_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(result.data()), + n); + sync_and_check("mean sum kernel failed"); + scale_f16_kernel<<<1, 1>>>( + static_cast<__half*>(result.data()), + 1.0f / static_cast(n)); + break; + } + case DataType::BFloat16: { + init_sum_bf16_kernel<<<1, 1>>>(static_cast<__nv_bfloat16*>(result.data())); + reduce_sum_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(result.data()), + n); + sync_and_check("mean sum kernel failed"); + scale_bf16_kernel<<<1, 1>>>( + static_cast<__nv_bfloat16*>(result.data()), + 1.0f / static_cast(n)); + break; + } + default: + break; + } + + sync_and_check("mean kernel failed"); + return result; +} + +GPUArray max(const GPUArray& a) { + if (a.dtype() != DataType::Float32 && a.dtype() != DataType::Float64 && + a.dtype() != DataType::Float16 && a.dtype() != DataType::BFloat16) { + throw std::runtime_error("max only supports float types"); + } + + GPUArray result({1}, a.dtype()); + size_t n = a.size(); + + const int block_size = 256; + const int max_blocks = 256; + const int grid_size = std::min((int)((n + block_size - 1) / block_size), max_blocks); + + switch (a.dtype()) { + case DataType::Float32: + init_max_f32_kernel<<<1, 1>>>(static_cast(result.data())); + reduce_max_f32_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + break; + case DataType::Float64: + init_max_f64_kernel<<<1, 1>>>(static_cast(result.data())); + reduce_max_f64_kernel<<>>( + static_cast(a.data()), + static_cast(result.data()), + n); + break; + case DataType::Float16: + init_max_f16_kernel<<<1, 1>>>(static_cast<__half*>(result.data())); + reduce_max_f16_kernel<<>>( + static_cast(a.data()), + static_cast<__half*>(result.data()), + n); + break; + case DataType::BFloat16: + init_max_bf16_kernel<<<1, 1>>>(static_cast<__nv_bfloat16*>(result.data())); + reduce_max_bf16_kernel<<>>( + static_cast(a.data()), + static_cast<__nv_bfloat16*>(result.data()), + n); + break; + default: + break; + } + + sync_and_check("max kernel failed"); + return result; +} + // ============================================================================ // Matmul kernels - Tiled with Shared Memory and Double Buffering // ============================================================================ @@ -1230,34 +1971,69 @@ void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c) { static_cast(c.data()), M, N, K); break; + case DataType::Float16: + fp16_bf16_matmul::launch_sgemm_f16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + M, N, K); + break; + case DataType::BFloat16: + fp16_bf16_matmul::launch_sgemm_bf16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + M, N, K); + break; default: - throw std::runtime_error("matmul only supports float32 and float64"); + throw std::runtime_error("matmul only supports float types"); } } else { // L2-optimized kernel for small matrices (Ampere+) - dim3 block_size(BLOCK_SIZE, BLOCK_SIZE); - dim3 grid_size( - (N + BLOCK_SIZE - 1) / BLOCK_SIZE, - (M + BLOCK_SIZE - 1) / BLOCK_SIZE - ); - + // or FP16/BF16 kernels switch (a.dtype()) { - case DataType::Float32: + case DataType::Float32: { + dim3 block_size(BLOCK_SIZE, BLOCK_SIZE); + dim3 grid_size( + (N + BLOCK_SIZE - 1) / BLOCK_SIZE, + (M + BLOCK_SIZE - 1) / BLOCK_SIZE + ); matmul_f32_l2opt_kernel<<>>( static_cast(a.data()), static_cast(b.data()), static_cast(c.data()), M, N, K); break; - case DataType::Float64: + } + case DataType::Float64: { + dim3 block_size(BLOCK_SIZE, BLOCK_SIZE); + dim3 grid_size( + (N + BLOCK_SIZE - 1) / BLOCK_SIZE, + (M + BLOCK_SIZE - 1) / BLOCK_SIZE + ); matmul_f64_l2opt_kernel<<>>( static_cast(a.data()), static_cast(b.data()), static_cast(c.data()), M, N, K); break; + } + case DataType::Float16: + fp16_bf16_matmul::launch_sgemm_f16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + M, N, K); + break; + case DataType::BFloat16: + fp16_bf16_matmul::launch_sgemm_bf16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + M, N, K); + break; default: - throw std::runtime_error("matmul only supports float32 and float64"); + throw std::runtime_error("matmul only supports float types"); } } @@ -1378,8 +2154,22 @@ static void matmul_impl(const GPUArray& a, const GPUArray& b, GPUArray& c, bool static_cast(c.data()), M, N, K); break; + case DataType::Float16: + fp16_bf16_matmul::launch_sgemm_f16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + M, N, K); + break; + case DataType::BFloat16: + fp16_bf16_matmul::launch_sgemm_bf16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + M, N, K); + break; default: - throw std::runtime_error("matmul only supports float32 and float64"); + throw std::runtime_error("matmul only supports float32, float64, float16, and bfloat16"); } } else { dim3 block_size(BLOCK_SIZE, BLOCK_SIZE); @@ -1403,8 +2193,22 @@ static void matmul_impl(const GPUArray& a, const GPUArray& b, GPUArray& c, bool static_cast(c.data()), M, N, K); break; + case DataType::Float16: + fp16_bf16_matmul::launch_sgemm_f16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__half*>(c.data()), + M, N, K); + break; + case DataType::BFloat16: + fp16_bf16_matmul::launch_sgemm_bf16( + static_cast(a.data()), + static_cast(b.data()), + static_cast<__nv_bfloat16*>(c.data()), + M, N, K); + break; default: - throw std::runtime_error("matmul only supports float32 and float64"); + throw std::runtime_error("matmul only supports float32, float64, float16, and bfloat16"); } } diff --git a/native/ops/basic.cuh b/native/ops/basic.cuh index 1ed1af3..2c90d6c 100644 --- a/native/ops/basic.cuh +++ b/native/ops/basic.cuh @@ -56,5 +56,18 @@ GPUArray matmul(const GPUArray& a, const GPUArray& b); void matmul(const GPUArray& a, const GPUArray& b, GPUArray& c, bool use_tf32); GPUArray matmul(const GPUArray& a, const GPUArray& b, bool use_tf32); +// ============================================================================ +// Reduction Operations +// ============================================================================ + +// Sum of all elements: returns a scalar GPUArray with shape {1} +GPUArray sum(const GPUArray& a); + +// Mean of all elements: returns a scalar GPUArray with shape {1} +GPUArray mean(const GPUArray& a); + +// Max of all elements: returns a scalar GPUArray with shape {1} +GPUArray max(const GPUArray& a); + } // namespace ops } // namespace pygpukit diff --git a/native/ops/matmul_f16_bf16.cuh b/native/ops/matmul_f16_bf16.cuh new file mode 100644 index 0000000..e5ac40b --- /dev/null +++ b/native/ops/matmul_f16_bf16.cuh @@ -0,0 +1,88 @@ +/** + * FP16/BF16 Matrix Multiplication + * + * Uses FP32 accumulation for numerical stability + * Supports: + * - FP16 input -> FP16 output (FP32 accumulation) + * - BF16 input -> BF16 output (FP32 accumulation) + * + * Note: WMMA/TensorCore optimization can be added later. + * Current implementation uses simple kernels for correctness. + */ + +#pragma once +#include +#include +#include + +namespace pygpukit { +namespace ops { +namespace fp16_bf16_matmul { + +// Simple FP16 GEMM using FP32 accumulation +// C = A @ B where A is (M, K), B is (K, N), C is (M, N) +__global__ void sgemm_f16_simple_kernel( + const __half* __restrict__ A, + const __half* __restrict__ B, + __half* __restrict__ C, + int M, int N, int K +) { + const int row = blockIdx.y * blockDim.y + threadIdx.y; + const int col = blockIdx.x * blockDim.x + threadIdx.x; + + if (row < M && col < N) { + float sum = 0.0f; + for (int k = 0; k < K; k++) { + sum += __half2float(A[row * K + k]) * __half2float(B[k * N + col]); + } + C[row * N + col] = __float2half(sum); + } +} + +// Simple BF16 GEMM using FP32 accumulation +// C = A @ B where A is (M, K), B is (K, N), C is (M, N) +__global__ void sgemm_bf16_simple_kernel( + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ B, + __nv_bfloat16* __restrict__ C, + int M, int N, int K +) { + const int row = blockIdx.y * blockDim.y + threadIdx.y; + const int col = blockIdx.x * blockDim.x + threadIdx.x; + + if (row < M && col < N) { + float sum = 0.0f; + for (int k = 0; k < K; k++) { + sum += __bfloat162float(A[row * K + k]) * __bfloat162float(B[k * N + col]); + } + C[row * N + col] = __float2bfloat16_rn(sum); + } +} + +// Launch FP16 matmul +inline cudaError_t launch_sgemm_f16( + const __half* A, const __half* B, __half* C, + int M, int N, int K, + cudaStream_t stream = 0 +) { + dim3 block(16, 16); + dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y); + sgemm_f16_simple_kernel<<>>(A, B, C, M, N, K); + return cudaGetLastError(); +} + +// Launch BF16 matmul +inline cudaError_t launch_sgemm_bf16( + const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C, + int M, int N, int K, + cudaStream_t stream = 0 +) { + dim3 block(16, 16); + dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y); + sgemm_bf16_simple_kernel<<>>(A, B, C, M, N, K); + return cudaGetLastError(); +} + +} // namespace fp16_bf16_matmul +} // namespace ops +} // namespace pygpukit diff --git a/src/pygpukit/__init__.py b/src/pygpukit/__init__.py index b952bc4..b49e4f1 100644 --- a/src/pygpukit/__init__.py +++ b/src/pygpukit/__init__.py @@ -10,7 +10,7 @@ get_device_info, is_cuda_available, ) -from pygpukit.core.dtypes import DataType, float32, float64, int32, int64 +from pygpukit.core.dtypes import DataType, float32, float64, float16, bfloat16, int32, int64 from pygpukit.core.factory import empty, from_numpy, ones, zeros from pygpukit.core.stream import Stream, StreamManager, default_stream from pygpukit.jit.compiler import ( @@ -27,7 +27,7 @@ jit, warmup, ) -from pygpukit.ops.basic import add, div, exp, log, matmul, mul, relu, sub +from pygpukit.ops.basic import add, div, exp, log, matmul, max, mean, mul, relu, sub, sum # Try to import Rust types, fallback to Python implementations try: @@ -53,6 +53,8 @@ "DataType", "float32", "float64", + "float16", + "bfloat16", "int32", "int64", # Factory functions @@ -86,4 +88,8 @@ "log", "relu", "matmul", + # Reductions + "sum", + "mean", + "max", ] diff --git a/src/pygpukit/core/array.py b/src/pygpukit/core/array.py index f7a8819..50a094d 100644 --- a/src/pygpukit/core/array.py +++ b/src/pygpukit/core/array.py @@ -62,7 +62,7 @@ def _wrap_native(cls, native_array: Any) -> GPUArray: This is the fast path for GPU operations - no data copying. """ from pygpukit.core.backend import get_native_module - from pygpukit.core.dtypes import float32, float64, int32, int64 + from pygpukit.core.dtypes import float32, float64, float16, bfloat16, int32, int64 native = get_native_module() @@ -72,6 +72,10 @@ def _wrap_native(cls, native_array: Any) -> GPUArray: dtype = float32 elif native_dtype == native.DataType.Float64: dtype = float64 + elif native_dtype == native.DataType.Float16: + dtype = float16 + elif native_dtype == native.DataType.BFloat16: + dtype = bfloat16 elif native_dtype == native.DataType.Int32: dtype = int32 elif native_dtype == native.DataType.Int64: @@ -174,6 +178,38 @@ def to_numpy(self) -> np.ndarray: flat_array = backend.copy_device_to_host(self._device_ptr, self.nbytes, self._dtype) return flat_array.reshape(self._shape) + def is_contiguous(self) -> bool: + """Check if the array is contiguous in memory. + + Returns: + Always True, as PyGPUkit arrays are always contiguous (no stride support). + """ + return True + + def contiguous(self) -> GPUArray: + """Return a contiguous array. + + If the array is already contiguous, returns self. + Otherwise, returns a contiguous copy. + + Returns: + A contiguous GPUArray (always self, since arrays are always contiguous). + """ + # All PyGPUkit arrays are contiguous (no stride support yet) + return self + + def clone(self) -> GPUArray: + """Create a deep copy of the array. + + Returns: + A new GPUArray with copied data. + """ + from pygpukit.core.factory import from_numpy + + # Copy via NumPy (simple and reliable) + np_data = self.to_numpy().copy() + return from_numpy(np_data) + def __repr__(self) -> str: backend_type = "native" if self._native is not None else "simulation" return f"GPUArray(shape={self._shape}, dtype={self._dtype.name}, backend={backend_type})" diff --git a/src/pygpukit/core/dtypes.py b/src/pygpukit/core/dtypes.py index 0a5b5dc..0eb9ac1 100644 --- a/src/pygpukit/core/dtypes.py +++ b/src/pygpukit/core/dtypes.py @@ -12,6 +12,8 @@ class DataTypeKind(Enum): FLOAT32 = "float32" FLOAT64 = "float64" + FLOAT16 = "float16" + BFLOAT16 = "bfloat16" INT32 = "int32" INT64 = "int64" @@ -43,6 +45,8 @@ def to_numpy_dtype(self) -> Any: dtype_map = { DataTypeKind.FLOAT32: np.float32, DataTypeKind.FLOAT64: np.float64, + DataTypeKind.FLOAT16: np.float16, + DataTypeKind.BFLOAT16: np.uint16, # NumPy has no native bfloat16 DataTypeKind.INT32: np.int32, DataTypeKind.INT64: np.int64, } @@ -60,6 +64,11 @@ def from_numpy_dtype(dtype: Any) -> DataType: return float32 elif name == "float64": return float64 + elif name == "float16": + return float16 + elif name == "uint16": + # uint16 is used as storage for bfloat16 + return bfloat16 elif name == "int32": return int32 elif name == "int64": @@ -73,6 +82,8 @@ def from_string(name: str) -> DataType: type_map = { "float32": float32, "float64": float64, + "float16": float16, + "bfloat16": bfloat16, "int32": int32, "int64": int64, } @@ -84,5 +95,7 @@ def from_string(name: str) -> DataType: # Pre-defined data types float32 = DataType(DataTypeKind.FLOAT32, 4, "float32") float64 = DataType(DataTypeKind.FLOAT64, 8, "float64") +float16 = DataType(DataTypeKind.FLOAT16, 2, "float16") +bfloat16 = DataType(DataTypeKind.BFLOAT16, 2, "bfloat16") int32 = DataType(DataTypeKind.INT32, 4, "int32") int64 = DataType(DataTypeKind.INT64, 8, "int64") diff --git a/src/pygpukit/core/factory.py b/src/pygpukit/core/factory.py index 9d6b9f3..4f9a4d5 100644 --- a/src/pygpukit/core/factory.py +++ b/src/pygpukit/core/factory.py @@ -205,12 +205,16 @@ def _from_numpy_native(array: np.ndarray) -> GPUArray: def _to_native_dtype(dtype: DataType, native: Any) -> Any: """Convert Python DataType to native DataType.""" - from pygpukit.core.dtypes import float32, float64, int32, int64 + from pygpukit.core.dtypes import float32, float64, float16, bfloat16, int32, int64 if dtype == float32: return native.DataType.Float32 elif dtype == float64: return native.DataType.Float64 + elif dtype == float16: + return native.DataType.Float16 + elif dtype == bfloat16: + return native.DataType.BFloat16 elif dtype == int32: return native.DataType.Int32 elif dtype == int64: diff --git a/src/pygpukit/ops/__init__.py b/src/pygpukit/ops/__init__.py index 9b14d2f..c75125e 100644 --- a/src/pygpukit/ops/__init__.py +++ b/src/pygpukit/ops/__init__.py @@ -1,5 +1,5 @@ """Operations module for PyGPUkit.""" -from pygpukit.ops.basic import add, matmul, mul +from pygpukit.ops.basic import add, div, exp, log, matmul, max, mean, mul, relu, sub, sum -__all__ = ["add", "mul", "matmul"] +__all__ = ["add", "sub", "mul", "div", "exp", "log", "relu", "matmul", "sum", "mean", "max"] diff --git a/src/pygpukit/ops/basic.py b/src/pygpukit/ops/basic.py index b10aa77..a8f47de 100644 --- a/src/pygpukit/ops/basic.py +++ b/src/pygpukit/ops/basic.py @@ -23,9 +23,9 @@ def _validate_same_dtype(a: GPUArray, b: GPUArray, op_name: str) -> None: def _validate_float_dtype(a: GPUArray, op_name: str) -> None: """Validate that array has float dtype.""" - from pygpukit.core.dtypes import float32, float64 - if a.dtype not in (float32, float64): - raise ValueError(f"{op_name} requires float32 or float64 dtype, got {a.dtype}") + from pygpukit.core.dtypes import float32, float64, float16, bfloat16 + if a.dtype not in (float32, float64, float16, bfloat16): + raise ValueError(f"{op_name} requires float dtype, got {a.dtype}") def add(a: GPUArray, b: GPUArray) -> GPUArray: @@ -413,3 +413,125 @@ def _matmul_native(a: GPUArray, b: GPUArray, *, use_tf32: bool | None = None) -> # Wrap result (zero-copy) return GPUArray._wrap_native(c_native) + + +# ============================================================================ +# Reduction Operations +# ============================================================================ + + +def sum(a: GPUArray) -> GPUArray: + """Sum of all elements. + + Args: + a: Input array (float32 or float64). + + Returns: + A scalar GPUArray (shape [1]) containing the sum. + + Raises: + ValueError: If dtype is not float32 or float64. + """ + _validate_float_dtype(a, "sum") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _sum_native(a) + else: + return _sum_cpu(a) + + +def _sum_cpu(a: GPUArray) -> GPUArray: + """CPU implementation of sum.""" + a_np = a.to_numpy() + result_np = np.array([np.sum(a_np)], dtype=a_np.dtype) + return from_numpy(result_np) + + +def _sum_native(a: GPUArray) -> GPUArray: + """Native C++ CUDA implementation of sum (zero-copy).""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + a_native = a._get_native() + c_native = native.sum(a_native) + return GPUArray._wrap_native(c_native) + + +def mean(a: GPUArray) -> GPUArray: + """Mean of all elements. + + Args: + a: Input array (float32 or float64). + + Returns: + A scalar GPUArray (shape [1]) containing the mean. + + Raises: + ValueError: If dtype is not float32 or float64. + """ + _validate_float_dtype(a, "mean") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _mean_native(a) + else: + return _mean_cpu(a) + + +def _mean_cpu(a: GPUArray) -> GPUArray: + """CPU implementation of mean.""" + a_np = a.to_numpy() + result_np = np.array([np.mean(a_np)], dtype=a_np.dtype) + return from_numpy(result_np) + + +def _mean_native(a: GPUArray) -> GPUArray: + """Native C++ CUDA implementation of mean (zero-copy).""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + a_native = a._get_native() + c_native = native.mean(a_native) + return GPUArray._wrap_native(c_native) + + +def max(a: GPUArray) -> GPUArray: + """Max of all elements. + + Args: + a: Input array (float32 or float64). + + Returns: + A scalar GPUArray (shape [1]) containing the maximum value. + + Raises: + ValueError: If dtype is not float32 or float64. + """ + _validate_float_dtype(a, "max") + + backend = get_backend() + + if isinstance(backend, NativeBackend) and backend.is_available(): + return _max_native(a) + else: + return _max_cpu(a) + + +def _max_cpu(a: GPUArray) -> GPUArray: + """CPU implementation of max.""" + a_np = a.to_numpy() + result_np = np.array([np.max(a_np)], dtype=a_np.dtype) + return from_numpy(result_np) + + +def _max_native(a: GPUArray) -> GPUArray: + """Native C++ CUDA implementation of max (zero-copy).""" + from pygpukit.core.backend import get_native_module + + native = get_native_module() + a_native = a._get_native() + c_native = native.max(a_native) + return GPUArray._wrap_native(c_native) From 02645387132dd154a8102e935a98e424ce43634f Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 14:22:13 +0900 Subject: [PATCH 15/24] feat(#58): add operator overloads and astype method to GPUArray MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add arithmetic operators: +, -, *, /, @ - Add astype() for dtype conversion - Handle BF16 <-> FP32 conversion correctly Test results (RTX 3090 Ti): - FP16 elementwise: PASS - BF16 elementwise: PASS - FP16 matmul: PASS (rel error < 0.05) - BF16 matmul: PASS (rel error < 0.05) - Reduction ops (sum, mean, max): PASS Benchmark (simple kernels, no TensorCore): - FP16 4096x4096: 2.18 TFLOPS - BF16 4096x4096: 2.16 TFLOPS - FP32 4096x4096: 6.29 TFLOPS (reference) Note: FP16/BF16 matmul uses naive kernels. TensorCore optimization planned. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/pygpukit/core/array.py | 80 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/src/pygpukit/core/array.py b/src/pygpukit/core/array.py index 50a094d..c3cfd97 100644 --- a/src/pygpukit/core/array.py +++ b/src/pygpukit/core/array.py @@ -232,3 +232,83 @@ def __del__(self) -> None: except Exception: pass # Ignore errors during cleanup self._device_ptr = None + + # ======================================================================== + # Arithmetic operators + # ======================================================================== + + def __add__(self, other: GPUArray) -> GPUArray: + """Element-wise addition.""" + from pygpukit.ops.basic import add + return add(self, other) + + def __sub__(self, other: GPUArray) -> GPUArray: + """Element-wise subtraction.""" + from pygpukit.ops.basic import sub + return sub(self, other) + + def __mul__(self, other: GPUArray) -> GPUArray: + """Element-wise multiplication.""" + from pygpukit.ops.basic import mul + return mul(self, other) + + def __truediv__(self, other: GPUArray) -> GPUArray: + """Element-wise division.""" + from pygpukit.ops.basic import div + return div(self, other) + + def __matmul__(self, other: GPUArray) -> GPUArray: + """Matrix multiplication.""" + from pygpukit.ops.basic import matmul + return matmul(self, other) + + # ======================================================================== + # Type conversion + # ======================================================================== + + def astype(self, dtype: DataType) -> GPUArray: + """Convert array to a different data type. + + Args: + dtype: Target data type. + + Returns: + A new GPUArray with the specified dtype. + """ + if self._dtype == dtype: + return self + + from pygpukit.core.factory import from_numpy + from pygpukit.core.dtypes import bfloat16, float32, float16 + + # Get numpy array + np_data = self.to_numpy() + + # Handle BF16 source (stored as uint16) + if self._dtype == bfloat16: + # Convert BF16 (uint16) to FP32: shift left by 16 bits + bf16_as_uint32 = np_data.astype(np.uint32) << 16 + fp32_data = bf16_as_uint32.view(np.float32) + + if dtype == float32: + return from_numpy(fp32_data) + elif dtype == float16: + return from_numpy(fp32_data.astype(np.float16)) + else: + return from_numpy(fp32_data.astype(dtype.to_numpy_dtype())) + + # Convert to BF16 + if dtype == bfloat16: + # BF16: convert via float32, store as uint16 + fp32_data = np_data.astype(np.float32) + # Round to nearest even + uint32_view = fp32_data.view(np.uint32) + bf16_data = ((uint32_view + 0x7FFF + ((uint32_view >> 16) & 1)) >> 16).astype(np.uint16) + result = from_numpy(bf16_data) + # Override dtype to bfloat16 + result._dtype = dtype + return result + else: + target_np_dtype = dtype.to_numpy_dtype() + converted = np_data.astype(target_np_dtype) + return from_numpy(converted) From 2a9b6d81d05df842b2a671e5497109ccbb43aff7 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 14:32:35 +0900 Subject: [PATCH 16/24] docs: add v0.2.5 demo and update README with FP16/BF16 benchmarks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add examples/demo_v025.py with full feature demonstration - Update README.md with v0.2.5 features (FP16/BF16, reductions, operators) - Add FP16/BF16 benchmark results to performance table - Update roadmap: v0.2.5 released, v0.2.6+ planned Demo output (RTX 3090 Ti): - FP16/BF16 elementwise: PASS - FP16/BF16 matmul: PASS - Reduction ops: PASS Benchmark (8192x8192): - FP32: 12.7 TFLOPS - TF32: 13.0 TFLOPS - FP16: 2.3 TFLOPS (simple kernel) - BF16: 2.2 TFLOPS (simple kernel) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- README.md | 62 ++++++++- examples/demo_v025.py | 310 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 365 insertions(+), 7 deletions(-) create mode 100644 examples/demo_v025.py diff --git a/README.md b/README.md index 38610d8..af0ff91 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,51 @@ PyGPUkit aims to be the "micro-runtime for GPU computing": small, fast, and idea --- +## What's New in v0.2.5 + +### FP16 / BF16 Support +| Feature | Description | +|---------|-------------| +| **FP16 (float16)** | Half-precision floating point | +| **BF16 (bfloat16)** | Brain floating point (better dynamic range) | +| **FP32 Accumulation** | Numerical stability via FP32 intermediate | +| **Type Conversion** | `astype()` for seamless dtype conversion | + +```python +import pygpukit as gpk +import numpy as np + +# FP16 operations +a = gpk.from_numpy(np.random.randn(1024, 1024).astype(np.float16)) +b = gpk.from_numpy(np.random.randn(1024, 1024).astype(np.float16)) +c = a @ b # FP16 matmul + +# BF16 operations +arr = np.random.randn(1024, 1024).astype(np.float32) +a_bf16 = gpk.from_numpy(arr).astype(gpk.bfloat16) +b_bf16 = gpk.from_numpy(arr).astype(gpk.bfloat16) +c_bf16 = a_bf16 @ b_bf16 # BF16 matmul +result = c_bf16.astype(gpk.float32) # Convert back to FP32 +``` + +### Reduction Operations +| Operation | Description | +|-----------|-------------| +| `gpk.sum(a)` | Sum of all elements | +| `gpk.mean(a)` | Mean of all elements | +| `gpk.max(a)` | Maximum element | + +### Operator Overloads +```python +c = a + b # Element-wise add +c = a - b # Element-wise subtract +c = a * b # Element-wise multiply +c = a / b # Element-wise divide +c = a @ b # Matrix multiplication +``` + +--- + ## What's New in v0.2.4 ### Single-Binary Distribution @@ -65,11 +110,13 @@ print(f"NVRTC Path: {gp.get_nvrtc_path()}") # Path to NVRTC DLL (if available) ### PyGPUkit Performance by Matrix Size -| Matrix Size | FP32 | TF32 (Driver-Only) | TF32 (Full) | -|-------------|------|-------------------|-------------| -| 2048×2048 | 8.7 TFLOPS | 12.2 TFLOPS | 13.0 TFLOPS | -| 4096×4096 | 14.2 TFLOPS | 22.0 TFLOPS | 23.5 TFLOPS | -| 8192×8192 | 17.7 TFLOPS | 28.2 TFLOPS | **30.3 TFLOPS** | +| Matrix Size | FP32 | TF32 | FP16 | BF16 | +|-------------|------|------|------|------| +| 2048×2048 | 4.0 TFLOPS | 4.0 TFLOPS | 2.0 TFLOPS | 2.0 TFLOPS | +| 4096×4096 | 8.1 TFLOPS | 8.2 TFLOPS | 2.2 TFLOPS | 2.2 TFLOPS | +| 8192×8192 | 12.7 TFLOPS | 13.0 TFLOPS | 2.3 TFLOPS | 2.2 TFLOPS | + +> **Note:** FP16/BF16 matmul uses simple kernels with FP32 accumulation. TensorCore optimization planned for future releases (see [Issue #60](https://github.com/m96-chan/PyGPUkit/issues/60)). --- @@ -218,13 +265,14 @@ PyGPUkit/ | **v0.2.2** | Ampere SGEMM (cp.async, float4), 18 TFLOPS FP32 | | **v0.2.3** | TF32 TensorCore (PTX mma.sync), 28 TFLOPS | | **v0.2.4** | **Single-binary distribution**, dynamic NVRTC, driver-only mode | +| **v0.2.5** | **FP16/BF16 support**, reduction ops (sum, mean, max), operator overloads | ### Planned | Version | Goals | |---------|-------| -| **v0.2.5** | Multi-GPU detection, NCCL preliminary support | -| **v0.2.6** | Full API review, documentation, backward compatibility | +| **v0.2.6** | FP16/BF16 TensorCore optimization, Multi-GPU detection | +| **v0.2.7** | Full API review, documentation, backward compatibility | | **v0.3** | Triton backend, advanced ops (softmax, layernorm), MPS/MIG | --- diff --git a/examples/demo_v025.py b/examples/demo_v025.py new file mode 100644 index 0000000..7d73f59 --- /dev/null +++ b/examples/demo_v025.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python3 +""" +PyGPUkit v0.2.5 Full Feature Demo + +Demonstrates all features available in v0.2.5: +- Data types: FP32, FP16, BF16 +- Elementwise operations: add, mul, sub, div +- Matrix multiplication: FP32, TF32 (TensorCore), FP16, BF16 +- Reduction operations: sum, mean, max +- Type conversion: astype() +""" + +import numpy as np +import time +import os + +# Set TF32 environment before import +os.environ["PYGPUKIT_ALLOW_TF32"] = "1" + +import pygpukit as gpk + + +def section(title: str) -> None: + """Print section header.""" + print() + print("=" * 60) + print(f" {title}") + print("=" * 60) + + +def benchmark_matmul(a, b, name: str, warmup: int = 3, iterations: int = 10) -> float: + """Benchmark matmul and return TFLOPS.""" + M, K = a.shape + _, N = b.shape + + # Warmup + for _ in range(warmup): + c = a @ b + _ = c.to_numpy() + + # Benchmark + times = [] + for _ in range(iterations): + start = time.perf_counter() + c = a @ b + _ = c.to_numpy() + end = time.perf_counter() + times.append(end - start) + + avg_time = np.mean(times) + flops = 2.0 * M * N * K + tflops = flops / avg_time / 1e12 + + print(f" {name}: {avg_time*1000:.2f} ms, {tflops:.2f} TFLOPS") + return tflops + + +def demo_dtypes(): + """Demonstrate supported data types.""" + section("Data Types") + + print("Supported dtypes:") + print(f" - gpk.float32: {gpk.float32}") + print(f" - gpk.float64: {gpk.float64}") + print(f" - gpk.float16: {gpk.float16}") + print(f" - gpk.bfloat16: {gpk.bfloat16}") + print(f" - gpk.int32: {gpk.int32}") + print(f" - gpk.int64: {gpk.int64}") + + # Create arrays with different dtypes + print() + print("Creating arrays:") + + a_fp32 = gpk.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + print(f" FP32: {a_fp32}") + + a_fp16 = gpk.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float16)) + print(f" FP16: {a_fp16}") + + a_bf16 = gpk.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)).astype(gpk.bfloat16) + print(f" BF16: {a_bf16}") + + +def demo_elementwise(): + """Demonstrate elementwise operations.""" + section("Elementwise Operations") + + for dtype_name, np_dtype, gpk_dtype in [ + ("FP32", np.float32, None), + ("FP16", np.float16, None), + ("BF16", np.float32, gpk.bfloat16), + ]: + print(f"\n{dtype_name}:") + + a_np = np.array([1.0, 2.0, 3.0, 4.0], dtype=np_dtype) + b_np = np.array([0.5, 1.5, 2.5, 3.5], dtype=np_dtype) + + if gpk_dtype == gpk.bfloat16: + a = gpk.from_numpy(a_np).astype(gpk.bfloat16) + b = gpk.from_numpy(b_np).astype(gpk.bfloat16) + else: + a = gpk.from_numpy(a_np) + b = gpk.from_numpy(b_np) + + # Operations + add_result = (a + b) + mul_result = (a * b) + sub_result = (a - b) + div_result = (a / b) + + # Convert back for display + if gpk_dtype == gpk.bfloat16: + add_np = add_result.astype(gpk.float32).to_numpy() + mul_np = mul_result.astype(gpk.float32).to_numpy() + sub_np = sub_result.astype(gpk.float32).to_numpy() + div_np = div_result.astype(gpk.float32).to_numpy() + else: + add_np = add_result.to_numpy() + mul_np = mul_result.to_numpy() + sub_np = sub_result.to_numpy() + div_np = div_result.to_numpy() + + print(f" a = {a_np}") + print(f" b = {b_np}") + print(f" a + b = {add_np}") + print(f" a * b = {mul_np}") + print(f" a - b = {sub_np}") + print(f" a / b = {np.round(div_np, 3)}") + + +def demo_matmul(): + """Demonstrate matrix multiplication.""" + section("Matrix Multiplication") + + size = 1024 + print(f"Matrix size: {size}x{size}") + print() + + # FP32 + print("FP32 Matmul:") + a_fp32 = gpk.from_numpy(np.random.randn(size, size).astype(np.float32)) + b_fp32 = gpk.from_numpy(np.random.randn(size, size).astype(np.float32)) + c = a_fp32 @ b_fp32 + print(f" Result shape: {c.shape}, dtype: {c.dtype}") + benchmark_matmul(a_fp32, b_fp32, "Performance") + + # TF32 (TensorCore) + print("\nTF32 Matmul (TensorCore):") + c_tf32 = gpk.matmul(a_fp32, b_fp32, use_tf32=True) + print(f" Result shape: {c_tf32.shape}, dtype: {c_tf32.dtype}") + + # Accuracy check + c_np = c.to_numpy() + c_tf32_np = c_tf32.to_numpy() + rel_err = np.max(np.abs(c_np - c_tf32_np)) / np.max(np.abs(c_np)) + print(f" TF32 vs FP32 rel error: {rel_err:.6f}") + + # FP16 + print("\nFP16 Matmul:") + a_fp16 = gpk.from_numpy(np.random.randn(size, size).astype(np.float16)) + b_fp16 = gpk.from_numpy(np.random.randn(size, size).astype(np.float16)) + c_fp16 = a_fp16 @ b_fp16 + print(f" Result shape: {c_fp16.shape}, dtype: {c_fp16.dtype}") + benchmark_matmul(a_fp16, b_fp16, "Performance") + + # BF16 + print("\nBF16 Matmul:") + a_bf16 = gpk.from_numpy(np.random.randn(size, size).astype(np.float32)).astype(gpk.bfloat16) + b_bf16 = gpk.from_numpy(np.random.randn(size, size).astype(np.float32)).astype(gpk.bfloat16) + c_bf16 = a_bf16 @ b_bf16 + print(f" Result shape: {c_bf16.shape}, dtype: {c_bf16.dtype}") + benchmark_matmul(a_bf16, b_bf16, "Performance") + + +def demo_reductions(): + """Demonstrate reduction operations.""" + section("Reduction Operations") + + a_np = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32) + a = gpk.from_numpy(a_np) + + print(f"Input: {a_np}") + print() + + # Sum + s = gpk.sum(a) + print(f"sum(a) = {s.to_numpy()[0]:.4f} (expected: {np.sum(a_np):.4f})") + + # Mean + m = gpk.mean(a) + print(f"mean(a) = {m.to_numpy()[0]:.4f} (expected: {np.mean(a_np):.4f})") + + # Max + mx = gpk.max(a) + print(f"max(a) = {mx.to_numpy()[0]:.4f} (expected: {np.max(a_np):.4f})") + + +def demo_astype(): + """Demonstrate type conversion.""" + section("Type Conversion (astype)") + + # FP32 -> FP16 + a_fp32 = gpk.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + a_fp16 = a_fp32.astype(gpk.float16) + print(f"FP32 -> FP16: {a_fp32} -> {a_fp16}") + + # FP32 -> BF16 + a_bf16 = a_fp32.astype(gpk.bfloat16) + print(f"FP32 -> BF16: {a_fp32} -> {a_bf16}") + + # BF16 -> FP32 + a_back = a_bf16.astype(gpk.float32) + print(f"BF16 -> FP32: {a_bf16} -> {a_back}") + print(f" Values: {a_back.to_numpy()}") + + +def demo_benchmark_full(): + """Full benchmark across all dtypes and sizes.""" + section("Full Benchmark") + + sizes = [1024, 2048, 4096] + + print("Matmul Performance (TFLOPS):") + print() + print(f"{'Size':<12} {'FP32':<10} {'TF32':<10} {'FP16':<10} {'BF16':<10}") + print("-" * 52) + + for size in sizes: + results = {} + + # FP32 + a = gpk.from_numpy(np.random.randn(size, size).astype(np.float32)) + b = gpk.from_numpy(np.random.randn(size, size).astype(np.float32)) + + # Warmup & benchmark FP32 + for _ in range(3): + _ = (a @ b).to_numpy() + + times = [] + for _ in range(5): + start = time.perf_counter() + _ = (a @ b).to_numpy() + times.append(time.perf_counter() - start) + flops = 2.0 * size ** 3 + results['FP32'] = flops / np.mean(times) / 1e12 + + # TF32 + for _ in range(3): + _ = gpk.matmul(a, b, use_tf32=True).to_numpy() + + times = [] + for _ in range(5): + start = time.perf_counter() + _ = gpk.matmul(a, b, use_tf32=True).to_numpy() + times.append(time.perf_counter() - start) + results['TF32'] = flops / np.mean(times) / 1e12 + + # FP16 + a16 = gpk.from_numpy(np.random.randn(size, size).astype(np.float16)) + b16 = gpk.from_numpy(np.random.randn(size, size).astype(np.float16)) + + for _ in range(3): + _ = (a16 @ b16).to_numpy() + + times = [] + for _ in range(5): + start = time.perf_counter() + _ = (a16 @ b16).to_numpy() + times.append(time.perf_counter() - start) + results['FP16'] = flops / np.mean(times) / 1e12 + + # BF16 + abf = gpk.from_numpy(np.random.randn(size, size).astype(np.float32)).astype(gpk.bfloat16) + bbf = gpk.from_numpy(np.random.randn(size, size).astype(np.float32)).astype(gpk.bfloat16) + + for _ in range(3): + _ = (abf @ bbf).to_numpy() + + times = [] + for _ in range(5): + start = time.perf_counter() + _ = (abf @ bbf).to_numpy() + times.append(time.perf_counter() - start) + results['BF16'] = flops / np.mean(times) / 1e12 + + print(f"{size}x{size:<7} {results['FP32']:<10.2f} {results['TF32']:<10.2f} {results['FP16']:<10.2f} {results['BF16']:<10.2f}") + + +def main(): + print("=" * 60) + print(" PyGPUkit v0.2.5 - Full Feature Demo") + print("=" * 60) + + # Show version and backend info + print(f"\nBackend: Native C++/CUDA") + print(f"TF32 enabled: {os.environ.get('PYGPUKIT_ALLOW_TF32', '0') == '1'}") + + demo_dtypes() + demo_elementwise() + demo_matmul() + demo_reductions() + demo_astype() + demo_benchmark_full() + + section("Demo Complete") + print("All v0.2.5 features demonstrated successfully!") + + +if __name__ == "__main__": + main() From 556c9c8ff4ca18a14f15a0a93d37fe6f578aab53 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 14:44:32 +0900 Subject: [PATCH 17/24] feat: add comprehensive benchmark script (benchmark_all.py) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Benchmarks all dtypes (FP32, TF32, FP16, BF16) with: - Correctness verification - Performance measurement (median TFLOPS) - README-compatible markdown output - Mode detection (Driver-Only vs Full JIT) Usage: python benchmark_all.py [--sizes 2048,4096,8192] [--quick] Results (RTX 3090 Ti, Full mode): | Size | FP32 | TF32 | FP16 | BF16 | |------|------|------|------|------| | 2048 | 13.2 | 13.3 | 2.4 | 2.4 | | 4096 | 22.6 | 23.6 | 2.4 | 2.4 | | 8192 | 30.3 | 30.2 | 2.4 | 2.3 | 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- benchmark_all.py | 481 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 481 insertions(+) create mode 100644 benchmark_all.py diff --git a/benchmark_all.py b/benchmark_all.py new file mode 100644 index 0000000..0ef32c3 --- /dev/null +++ b/benchmark_all.py @@ -0,0 +1,481 @@ +#!/usr/bin/env python3 +""" +PyGPUkit Comprehensive Benchmark + +Benchmarks all supported dtypes and runtime modes: +- FP32, TF32, FP16, BF16 +- Driver-Only mode vs Full (JIT) mode + +Usage: + python benchmark_all.py [--sizes SIZES] [--quick] + +Output format matches README.md tables for easy updates. +""" + +import argparse +import os +import sys +import time +from dataclasses import dataclass +from typing import Optional + +import numpy as np + +# ============================================================================= +# Setup CUDA DLL path (Windows) +# ============================================================================= +cuda_path = os.environ.get( + "CUDA_PATH", r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" +) +cuda_bin = os.path.join(cuda_path, "bin") +if os.path.isdir(cuda_bin): + if cuda_bin not in os.environ.get("PATH", ""): + os.environ["PATH"] = cuda_bin + os.pathsep + os.environ.get("PATH", "") + if hasattr(os, "add_dll_directory"): + os.add_dll_directory(cuda_bin) + + +# ============================================================================= +# Data Classes +# ============================================================================= +@dataclass +class BenchmarkResult: + dtype: str + size: int + tflops_median: float + tflops_max: float + time_ms: float + correct: bool + rel_error: float + + +@dataclass +class GPUInfo: + name: str + sm_major: int + sm_minor: int + nvrtc_available: bool + + +# ============================================================================= +# Native Module Import Helper +# ============================================================================= +_native_module = None + +def get_native_module(): + """Get native module with fallback.""" + global _native_module + if _native_module is not None: + return _native_module + try: + import _pygpukit_native as native + _native_module = native + except ImportError: + from pygpukit import _pygpukit_native as native + _native_module = native + return _native_module + + +# ============================================================================= +# Benchmark Functions +# ============================================================================= +def get_gpu_info() -> GPUInfo: + """Get GPU information.""" + native = get_native_module() + props = native.get_device_properties(0) + + # Check NVRTC availability + try: + import pygpukit as gpk + nvrtc = gpk.is_nvrtc_available() + except: + nvrtc = False + + return GPUInfo( + name=props.name, + sm_major=props.compute_capability_major, + sm_minor=props.compute_capability_minor, + nvrtc_available=nvrtc, + ) + + +def benchmark_fp32(size: int, warmup: int = 5, iterations: int = 10) -> BenchmarkResult: + """Benchmark FP32 matmul.""" + native = get_native_module() + + A = np.random.randn(size, size).astype(np.float32) + B = np.random.randn(size, size).astype(np.float32) + + A_gpu = native.from_numpy(A) + B_gpu = native.from_numpy(B) + + # Correctness + C_gpu = native.matmul(A_gpu, B_gpu) + C_result = C_gpu.to_numpy() + C_expected = A @ B + rel_error = np.max(np.abs(C_result - C_expected)) / np.max(np.abs(C_expected)) + correct = rel_error < 1e-3 # FP32 matmul has some numerical error due to order of operations + + # Warmup + for _ in range(warmup): + _ = native.matmul(A_gpu, B_gpu) + + # Benchmark + times = [] + for _ in range(iterations): + start = time.perf_counter() + _ = native.matmul(A_gpu, B_gpu) + elapsed = time.perf_counter() - start + times.append(elapsed) + + median_time = np.median(times) + min_time = np.min(times) + flops = 2.0 * size * size * size + + return BenchmarkResult( + dtype="FP32", + size=size, + tflops_median=flops / median_time / 1e12, + tflops_max=flops / min_time / 1e12, + time_ms=median_time * 1000, + correct=correct, + rel_error=rel_error, + ) + + +def benchmark_tf32(size: int, warmup: int = 5, iterations: int = 10) -> BenchmarkResult: + """Benchmark TF32 TensorCore matmul.""" + native = get_native_module() + + A = np.random.randn(size, size).astype(np.float32) + B = np.random.randn(size, size).astype(np.float32) + + A_gpu = native.from_numpy(A) + B_gpu = native.from_numpy(B) + + # Correctness (TF32 tolerance is higher) + C_gpu = native.matmul_tf32(A_gpu, B_gpu, True) + C_result = C_gpu.to_numpy() + C_expected = A @ B + rel_error = np.max(np.abs(C_result - C_expected)) / np.max(np.abs(C_expected)) + correct = rel_error < 1e-2 # TF32 has ~0.1% per-op error + + # Warmup + for _ in range(warmup): + _ = native.matmul_tf32(A_gpu, B_gpu, True) + + # Benchmark + times = [] + for _ in range(iterations): + start = time.perf_counter() + _ = native.matmul_tf32(A_gpu, B_gpu, True) + elapsed = time.perf_counter() - start + times.append(elapsed) + + median_time = np.median(times) + min_time = np.min(times) + flops = 2.0 * size * size * size + + return BenchmarkResult( + dtype="TF32", + size=size, + tflops_median=flops / median_time / 1e12, + tflops_max=flops / min_time / 1e12, + time_ms=median_time * 1000, + correct=correct, + rel_error=rel_error, + ) + + +def benchmark_fp16(size: int, warmup: int = 5, iterations: int = 10) -> BenchmarkResult: + """Benchmark FP16 matmul.""" + native = get_native_module() + + A = np.random.randn(size, size).astype(np.float16) + B = np.random.randn(size, size).astype(np.float16) + + A_gpu = native.from_numpy(A) + B_gpu = native.from_numpy(B) + + # Correctness + C_gpu = native.matmul(A_gpu, B_gpu) + C_result = C_gpu.to_numpy() + C_expected = (A.astype(np.float32) @ B.astype(np.float32)).astype(np.float16) + rel_error = np.max(np.abs(C_result.astype(np.float32) - C_expected.astype(np.float32))) / (np.max(np.abs(C_expected.astype(np.float32))) + 1e-7) + correct = rel_error < 0.05 + + # Warmup + for _ in range(warmup): + _ = native.matmul(A_gpu, B_gpu) + + # Benchmark + times = [] + for _ in range(iterations): + start = time.perf_counter() + _ = native.matmul(A_gpu, B_gpu) + elapsed = time.perf_counter() - start + times.append(elapsed) + + median_time = np.median(times) + min_time = np.min(times) + flops = 2.0 * size * size * size + + return BenchmarkResult( + dtype="FP16", + size=size, + tflops_median=flops / median_time / 1e12, + tflops_max=flops / min_time / 1e12, + time_ms=median_time * 1000, + correct=correct, + rel_error=rel_error, + ) + + +def benchmark_bf16(size: int, warmup: int = 5, iterations: int = 10) -> BenchmarkResult: + """Benchmark BF16 matmul.""" + native = get_native_module() + import pygpukit as gpk + + A_fp32 = np.random.randn(size, size).astype(np.float32) + B_fp32 = np.random.randn(size, size).astype(np.float32) + + # Convert to BF16 via GPUArray + A_gpu = gpk.from_numpy(A_fp32).astype(gpk.bfloat16)._get_native() + B_gpu = gpk.from_numpy(B_fp32).astype(gpk.bfloat16)._get_native() + + # Correctness + C_gpu = native.matmul(A_gpu, B_gpu) + # Convert result back to FP32 for comparison + C_gpk = gpk.GPUArray._wrap_native(C_gpu).astype(gpk.float32) + C_result = C_gpk.to_numpy() + C_expected = A_fp32 @ B_fp32 + rel_error = np.max(np.abs(C_result - C_expected)) / (np.max(np.abs(C_expected)) + 1e-7) + correct = rel_error < 0.05 + + # Re-create arrays for benchmark (previous ones consumed) + A_gpu = gpk.from_numpy(A_fp32).astype(gpk.bfloat16)._get_native() + B_gpu = gpk.from_numpy(B_fp32).astype(gpk.bfloat16)._get_native() + + # Warmup + for _ in range(warmup): + _ = native.matmul(A_gpu, B_gpu) + + # Benchmark + times = [] + for _ in range(iterations): + start = time.perf_counter() + _ = native.matmul(A_gpu, B_gpu) + elapsed = time.perf_counter() - start + times.append(elapsed) + + median_time = np.median(times) + min_time = np.min(times) + flops = 2.0 * size * size * size + + return BenchmarkResult( + dtype="BF16", + size=size, + tflops_median=flops / median_time / 1e12, + tflops_max=flops / min_time / 1e12, + time_ms=median_time * 1000, + correct=correct, + rel_error=rel_error, + ) + + +# ============================================================================= +# Output Functions +# ============================================================================= +def print_header(gpu_info: GPUInfo): + """Print benchmark header.""" + print("=" * 70) + print(" PyGPUkit Comprehensive Benchmark") + print("=" * 70) + print() + print(f"GPU: {gpu_info.name}") + print(f"SM: {gpu_info.sm_major}.{gpu_info.sm_minor}") + print(f"NVRTC (JIT): {'Available' if gpu_info.nvrtc_available else 'Not Available'}") + print(f"Mode: {'Full (Driver + JIT)' if gpu_info.nvrtc_available else 'Driver-Only'}") + print() + + +def print_correctness_results(results: list): + """Print correctness verification results.""" + print("=" * 70) + print(" Correctness Verification") + print("=" * 70) + print() + print(f"{'Dtype':<8} {'Size':<12} {'Rel Error':<12} {'Status':<8}") + print("-" * 44) + + for r in results: + status = "PASS" if r.correct else "FAIL" + print(f"{r.dtype:<8} {r.size}x{r.size:<6} {r.rel_error:<12.2e} {status:<8}") + print() + + +def print_benchmark_results(results: list, sizes: list): + """Print benchmark results in README-compatible table format.""" + print("=" * 70) + print(" Performance Results (TFLOPS)") + print("=" * 70) + print() + + # Group by size + by_size = {} + for r in results: + if r.size not in by_size: + by_size[r.size] = {} + by_size[r.size][r.dtype] = r + + # Print table + print(f"{'Size':<14} {'FP32':<10} {'TF32':<10} {'FP16':<10} {'BF16':<10}") + print("-" * 54) + + for size in sizes: + if size not in by_size: + continue + row = by_size[size] + fp32 = row.get("FP32") + tf32 = row.get("TF32") + fp16 = row.get("FP16") + bf16 = row.get("BF16") + + fp32_str = f"{fp32.tflops_median:.1f}" if fp32 else "-" + tf32_str = f"{tf32.tflops_median:.1f}" if tf32 else "-" + fp16_str = f"{fp16.tflops_median:.1f}" if fp16 else "-" + bf16_str = f"{bf16.tflops_median:.1f}" if bf16 else "-" + + print(f"{size}x{size:<8} {fp32_str:<10} {tf32_str:<10} {fp16_str:<10} {bf16_str:<10}") + + print() + + +def print_readme_table(results: list, sizes: list, mode: str): + """Print README.md compatible markdown table.""" + print("=" * 70) + print(f" README.md Table ({mode})") + print("=" * 70) + print() + + # Group by size + by_size = {} + for r in results: + if r.size not in by_size: + by_size[r.size] = {} + by_size[r.size][r.dtype] = r + + print("| Matrix Size | FP32 | TF32 | FP16 | BF16 |") + print("|-------------|------|------|------|------|") + + for size in sizes: + if size not in by_size: + continue + row = by_size[size] + fp32 = row.get("FP32") + tf32 = row.get("TF32") + fp16 = row.get("FP16") + bf16 = row.get("BF16") + + fp32_str = f"{fp32.tflops_median:.1f} TFLOPS" if fp32 else "-" + tf32_str = f"{tf32.tflops_median:.1f} TFLOPS" if tf32 else "-" + fp16_str = f"{fp16.tflops_median:.1f} TFLOPS" if fp16 else "-" + bf16_str = f"{bf16.tflops_median:.1f} TFLOPS" if bf16 else "-" + + print(f"| {size}x{size} | {fp32_str} | {tf32_str} | {fp16_str} | {bf16_str} |") + + print() + + +# ============================================================================= +# Main +# ============================================================================= +def main(): + parser = argparse.ArgumentParser(description="PyGPUkit Comprehensive Benchmark") + parser.add_argument("--sizes", type=str, default="2048,4096,8192", + help="Comma-separated matrix sizes (default: 2048,4096,8192)") + parser.add_argument("--quick", action="store_true", + help="Quick mode: fewer iterations") + parser.add_argument("--dtypes", type=str, default="fp32,tf32,fp16,bf16", + help="Comma-separated dtypes to benchmark") + args = parser.parse_args() + + sizes = [int(s.strip()) for s in args.sizes.split(",")] + dtypes = [d.strip().lower() for d in args.dtypes.split(",")] + + warmup = 3 if args.quick else 5 + iterations = 5 if args.quick else 10 + + # Setup environment for TF32 + os.environ["PYGPUKIT_ALLOW_TF32"] = "1" + os.environ["PYGPUKIT_TF32_V2"] = "1" + + # Get GPU info + gpu_info = get_gpu_info() + print_header(gpu_info) + + mode = "Full (Driver + JIT)" if gpu_info.nvrtc_available else "Driver-Only" + + # Run benchmarks + results = [] + + print("Running benchmarks...") + print() + + for size in sizes: + iters = iterations // 2 if size >= 8192 else iterations + + if "fp32" in dtypes: + print(f" FP32 {size}x{size}...", end=" ", flush=True) + r = benchmark_fp32(size, warmup, iters) + results.append(r) + print(f"{r.tflops_median:.1f} TFLOPS") + + if "tf32" in dtypes: + print(f" TF32 {size}x{size}...", end=" ", flush=True) + r = benchmark_tf32(size, warmup, iters) + results.append(r) + print(f"{r.tflops_median:.1f} TFLOPS") + + if "fp16" in dtypes: + print(f" FP16 {size}x{size}...", end=" ", flush=True) + r = benchmark_fp16(size, warmup, iters) + results.append(r) + print(f"{r.tflops_median:.1f} TFLOPS") + + if "bf16" in dtypes: + print(f" BF16 {size}x{size}...", end=" ", flush=True) + r = benchmark_bf16(size, warmup, iters) + results.append(r) + print(f"{r.tflops_median:.1f} TFLOPS") + + print() + + # Print results + print_correctness_results(results) + print_benchmark_results(results, sizes) + print_readme_table(results, sizes, mode) + + # Summary + print("=" * 70) + print(" Summary") + print("=" * 70) + print() + print(f"Mode: {mode}") + print(f"GPU: {gpu_info.name}") + + # Find peak performance + if results: + peak = max(results, key=lambda r: r.tflops_median) + print(f"Peak: {peak.tflops_median:.1f} TFLOPS ({peak.dtype}, {peak.size}x{peak.size})") + + print() + print("RTX 3090 Ti Theoretical:") + print(" FP32: ~40 TFLOPS") + print(" TF32 TensorCore: ~80 TFLOPS (Sparse: ~156 TFLOPS)") + print(" FP16 TensorCore: ~160 TFLOPS") + print() + + +if __name__ == "__main__": + main() From 1a0df0887cfdbba083276ea99f698813ad440bde Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 14:50:06 +0900 Subject: [PATCH 18/24] bench: comprehensive benchmark script with TF32 v1/v2 support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rewrote benchmark_all.py to use env vars for TF32 kernel selection - Added --tf32-version v1|v2 option - Clarified Driver-Only vs JIT modes (same matmul performance) - Updated README.md with accurate kernel-only timing results Benchmark results (RTX 3090 Ti): - FP32: 9.6 / 14.7 / 16.7 TFLOPS (2k/4k/8k) - TF32 v2: 13.2 / 22.8 / 29.7 TFLOPS (2k/4k/8k) - FP16: ~2.4 TFLOPS - BF16: ~2.3 TFLOPS 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- README.md | 11 ++- benchmark_all.py | 181 +++++++++++++++++++++++++++++------------------ 2 files changed, 117 insertions(+), 75 deletions(-) diff --git a/README.md b/README.md index af0ff91..d20db38 100644 --- a/README.md +++ b/README.md @@ -103,18 +103,17 @@ print(f"NVRTC Path: {gp.get_nvrtc_path()}") # Path to NVRTC DLL (if available) |---------|------|------|--------------| | **NumPy** (OpenBLAS) | ~0.8 TFLOPS | — | CPU only | | **cuBLAS** | ~21 TFLOPS | ~59 TFLOPS | CUDA Toolkit | -| **PyGPUkit** (Driver-Only) | 17.7 TFLOPS | 28.2 TFLOPS | GPU drivers only | -| **PyGPUkit** (Full) | 17.7 TFLOPS | 30.3 TFLOPS | GPU drivers + CUDA Toolkit | +| **PyGPUkit** | 16.7 TFLOPS | 29.7 TFLOPS | GPU drivers only | -> Driver-Only mode uses pre-compiled kernels. Full mode adds JIT compilation for custom kernels with slightly better TF32 optimization. +> Built-in matmul kernels are pre-compiled. Driver-Only and Full (JIT) modes have identical matmul performance. JIT is only needed for custom kernels. ### PyGPUkit Performance by Matrix Size | Matrix Size | FP32 | TF32 | FP16 | BF16 | |-------------|------|------|------|------| -| 2048×2048 | 4.0 TFLOPS | 4.0 TFLOPS | 2.0 TFLOPS | 2.0 TFLOPS | -| 4096×4096 | 8.1 TFLOPS | 8.2 TFLOPS | 2.2 TFLOPS | 2.2 TFLOPS | -| 8192×8192 | 12.7 TFLOPS | 13.0 TFLOPS | 2.3 TFLOPS | 2.2 TFLOPS | +| 2048×2048 | 9.6 TFLOPS | 13.2 TFLOPS | 2.4 TFLOPS | 2.4 TFLOPS | +| 4096×4096 | 14.7 TFLOPS | 22.8 TFLOPS | 2.4 TFLOPS | 2.3 TFLOPS | +| 8192×8192 | 16.7 TFLOPS | 29.7 TFLOPS | 2.3 TFLOPS | 2.3 TFLOPS | > **Note:** FP16/BF16 matmul uses simple kernels with FP32 accumulation. TensorCore optimization planned for future releases (see [Issue #60](https://github.com/m96-chan/PyGPUkit/issues/60)). diff --git a/benchmark_all.py b/benchmark_all.py index 0ef32c3..fb1daea 100644 --- a/benchmark_all.py +++ b/benchmark_all.py @@ -2,12 +2,22 @@ """ PyGPUkit Comprehensive Benchmark -Benchmarks all supported dtypes and runtime modes: -- FP32, TF32, FP16, BF16 -- Driver-Only mode vs Full (JIT) mode +Benchmarks all supported dtypes: +- FP32 (Ampere optimized kernel) +- TF32 v1 (WMMA TensorCore) +- TF32 v2 (PTX mma.sync TensorCore, optimized) +- FP16 (simple kernel, TensorCore planned) +- BF16 (simple kernel, TensorCore planned) + +Runtime Modes: +- Driver-Only: Uses pre-compiled kernels, no CUDA Toolkit needed +- Full (JIT): Same kernels + JIT compilation for custom ops + +Note: Built-in matmul kernels are pre-compiled, so Driver-Only and Full +modes have identical performance for matmul operations. Usage: - python benchmark_all.py [--sizes SIZES] [--quick] + python benchmark_all.py [--sizes SIZES] [--quick] [--tf32-version v1|v2] Output format matches README.md tables for easy updates. """ @@ -17,7 +27,6 @@ import sys import time from dataclasses import dataclass -from typing import Optional import numpy as np @@ -62,6 +71,7 @@ class GPUInfo: # ============================================================================= _native_module = None + def get_native_module(): """Get native module with fallback.""" global _native_module @@ -84,7 +94,6 @@ def get_gpu_info() -> GPUInfo: native = get_native_module() props = native.get_device_properties(0) - # Check NVRTC availability try: import pygpukit as gpk nvrtc = gpk.is_nvrtc_available() @@ -100,7 +109,7 @@ def get_gpu_info() -> GPUInfo: def benchmark_fp32(size: int, warmup: int = 5, iterations: int = 10) -> BenchmarkResult: - """Benchmark FP32 matmul.""" + """Benchmark FP32 matmul (Ampere optimized kernel).""" native = get_native_module() A = np.random.randn(size, size).astype(np.float32) @@ -114,7 +123,7 @@ def benchmark_fp32(size: int, warmup: int = 5, iterations: int = 10) -> Benchmar C_result = C_gpu.to_numpy() C_expected = A @ B rel_error = np.max(np.abs(C_result - C_expected)) / np.max(np.abs(C_expected)) - correct = rel_error < 1e-3 # FP32 matmul has some numerical error due to order of operations + correct = rel_error < 1e-3 # Warmup for _ in range(warmup): @@ -143,18 +152,30 @@ def benchmark_fp32(size: int, warmup: int = 5, iterations: int = 10) -> Benchmar ) -def benchmark_tf32(size: int, warmup: int = 5, iterations: int = 10) -> BenchmarkResult: - """Benchmark TF32 TensorCore matmul.""" +def benchmark_tf32(size: int, warmup: int = 5, iterations: int = 10, use_v2: bool = True) -> BenchmarkResult: + """Benchmark TF32 TensorCore matmul. + + Uses environment variables to control kernel selection: + - PYGPUKIT_ALLOW_TF32=1: Enable TF32 kernels + - PYGPUKIT_TF32_V2=1: Use optimized v2 kernel (PTX mma.sync) + """ native = get_native_module() + # Set environment for TF32 + os.environ["PYGPUKIT_ALLOW_TF32"] = "1" + if use_v2: + os.environ["PYGPUKIT_TF32_V2"] = "1" + else: + os.environ.pop("PYGPUKIT_TF32_V2", None) + A = np.random.randn(size, size).astype(np.float32) B = np.random.randn(size, size).astype(np.float32) A_gpu = native.from_numpy(A) B_gpu = native.from_numpy(B) - # Correctness (TF32 tolerance is higher) - C_gpu = native.matmul_tf32(A_gpu, B_gpu, True) + # Correctness - use native.matmul which respects env vars + C_gpu = native.matmul(A_gpu, B_gpu) C_result = C_gpu.to_numpy() C_expected = A @ B rel_error = np.max(np.abs(C_result - C_expected)) / np.max(np.abs(C_expected)) @@ -162,13 +183,13 @@ def benchmark_tf32(size: int, warmup: int = 5, iterations: int = 10) -> Benchmar # Warmup for _ in range(warmup): - _ = native.matmul_tf32(A_gpu, B_gpu, True) + _ = native.matmul(A_gpu, B_gpu) # Benchmark times = [] for _ in range(iterations): start = time.perf_counter() - _ = native.matmul_tf32(A_gpu, B_gpu, True) + _ = native.matmul(A_gpu, B_gpu) elapsed = time.perf_counter() - start times.append(elapsed) @@ -176,8 +197,9 @@ def benchmark_tf32(size: int, warmup: int = 5, iterations: int = 10) -> Benchmar min_time = np.min(times) flops = 2.0 * size * size * size + version = "v2" if use_v2 else "v1" return BenchmarkResult( - dtype="TF32", + dtype=f"TF32 {version}", size=size, tflops_median=flops / median_time / 1e12, tflops_max=flops / min_time / 1e12, @@ -188,7 +210,7 @@ def benchmark_tf32(size: int, warmup: int = 5, iterations: int = 10) -> Benchmar def benchmark_fp16(size: int, warmup: int = 5, iterations: int = 10) -> BenchmarkResult: - """Benchmark FP16 matmul.""" + """Benchmark FP16 matmul (simple kernel, no TensorCore yet).""" native = get_native_module() A = np.random.randn(size, size).astype(np.float16) @@ -201,7 +223,9 @@ def benchmark_fp16(size: int, warmup: int = 5, iterations: int = 10) -> Benchmar C_gpu = native.matmul(A_gpu, B_gpu) C_result = C_gpu.to_numpy() C_expected = (A.astype(np.float32) @ B.astype(np.float32)).astype(np.float16) - rel_error = np.max(np.abs(C_result.astype(np.float32) - C_expected.astype(np.float32))) / (np.max(np.abs(C_expected.astype(np.float32))) + 1e-7) + rel_error = np.max(np.abs(C_result.astype(np.float32) - C_expected.astype(np.float32))) / ( + np.max(np.abs(C_expected.astype(np.float32))) + 1e-7 + ) correct = rel_error < 0.05 # Warmup @@ -232,7 +256,7 @@ def benchmark_fp16(size: int, warmup: int = 5, iterations: int = 10) -> Benchmar def benchmark_bf16(size: int, warmup: int = 5, iterations: int = 10) -> BenchmarkResult: - """Benchmark BF16 matmul.""" + """Benchmark BF16 matmul (simple kernel, no TensorCore yet).""" native = get_native_module() import pygpukit as gpk @@ -245,14 +269,13 @@ def benchmark_bf16(size: int, warmup: int = 5, iterations: int = 10) -> Benchmar # Correctness C_gpu = native.matmul(A_gpu, B_gpu) - # Convert result back to FP32 for comparison C_gpk = gpk.GPUArray._wrap_native(C_gpu).astype(gpk.float32) C_result = C_gpk.to_numpy() C_expected = A_fp32 @ B_fp32 rel_error = np.max(np.abs(C_result - C_expected)) / (np.max(np.abs(C_expected)) + 1e-7) correct = rel_error < 0.05 - # Re-create arrays for benchmark (previous ones consumed) + # Re-create arrays for benchmark A_gpu = gpk.from_numpy(A_fp32).astype(gpk.bfloat16)._get_native() B_gpu = gpk.from_numpy(B_fp32).astype(gpk.bfloat16)._get_native() @@ -286,7 +309,7 @@ def benchmark_bf16(size: int, warmup: int = 5, iterations: int = 10) -> Benchmar # ============================================================================= # Output Functions # ============================================================================= -def print_header(gpu_info: GPUInfo): +def print_header(gpu_info: GPUInfo, tf32_version: str): """Print benchmark header.""" print("=" * 70) print(" PyGPUkit Comprehensive Benchmark") @@ -295,7 +318,10 @@ def print_header(gpu_info: GPUInfo): print(f"GPU: {gpu_info.name}") print(f"SM: {gpu_info.sm_major}.{gpu_info.sm_minor}") print(f"NVRTC (JIT): {'Available' if gpu_info.nvrtc_available else 'Not Available'}") - print(f"Mode: {'Full (Driver + JIT)' if gpu_info.nvrtc_available else 'Driver-Only'}") + print(f"TF32 Kernel: {tf32_version}") + print() + print("Note: Built-in matmul kernels are pre-compiled.") + print(" Driver-Only and Full modes have identical matmul performance.") print() @@ -305,17 +331,17 @@ def print_correctness_results(results: list): print(" Correctness Verification") print("=" * 70) print() - print(f"{'Dtype':<8} {'Size':<12} {'Rel Error':<12} {'Status':<8}") - print("-" * 44) + print(f"{'Dtype':<12} {'Size':<12} {'Rel Error':<12} {'Status':<8}") + print("-" * 48) for r in results: status = "PASS" if r.correct else "FAIL" - print(f"{r.dtype:<8} {r.size}x{r.size:<6} {r.rel_error:<12.2e} {status:<8}") + print(f"{r.dtype:<12} {r.size}x{r.size:<6} {r.rel_error:<12.2e} {status:<8}") print() def print_benchmark_results(results: list, sizes: list): - """Print benchmark results in README-compatible table format.""" + """Print benchmark results.""" print("=" * 70) print(" Performance Results (TFLOPS)") print("=" * 70) @@ -328,33 +354,40 @@ def print_benchmark_results(results: list, sizes: list): by_size[r.size] = {} by_size[r.size][r.dtype] = r - # Print table - print(f"{'Size':<14} {'FP32':<10} {'TF32':<10} {'FP16':<10} {'BF16':<10}") - print("-" * 54) + # Get all dtypes + all_dtypes = [] + for r in results: + if r.dtype not in all_dtypes: + all_dtypes.append(r.dtype) + + # Print header + header = f"{'Size':<14}" + for dt in all_dtypes: + header += f"{dt:<12}" + print(header) + print("-" * (14 + 12 * len(all_dtypes))) + # Print rows for size in sizes: if size not in by_size: continue row = by_size[size] - fp32 = row.get("FP32") - tf32 = row.get("TF32") - fp16 = row.get("FP16") - bf16 = row.get("BF16") - - fp32_str = f"{fp32.tflops_median:.1f}" if fp32 else "-" - tf32_str = f"{tf32.tflops_median:.1f}" if tf32 else "-" - fp16_str = f"{fp16.tflops_median:.1f}" if fp16 else "-" - bf16_str = f"{bf16.tflops_median:.1f}" if bf16 else "-" - - print(f"{size}x{size:<8} {fp32_str:<10} {tf32_str:<10} {fp16_str:<10} {bf16_str:<10}") + line = f"{size}x{size:<8}" + for dt in all_dtypes: + r = row.get(dt) + if r: + line += f"{r.tflops_median:<12.1f}" + else: + line += f"{'-':<12}" + print(line) print() -def print_readme_table(results: list, sizes: list, mode: str): +def print_readme_table(results: list, sizes: list): """Print README.md compatible markdown table.""" print("=" * 70) - print(f" README.md Table ({mode})") + print(" README.md Table") print("=" * 70) print() @@ -365,24 +398,33 @@ def print_readme_table(results: list, sizes: list, mode: str): by_size[r.size] = {} by_size[r.size][r.dtype] = r - print("| Matrix Size | FP32 | TF32 | FP16 | BF16 |") - print("|-------------|------|------|------|------|") + # Get dtypes + all_dtypes = [] + for r in results: + if r.dtype not in all_dtypes: + all_dtypes.append(r.dtype) + + # Print markdown table + header = "| Matrix Size |" + separator = "|-------------|" + for dt in all_dtypes: + header += f" {dt} |" + separator += "------|" + print(header) + print(separator) for size in sizes: if size not in by_size: continue row = by_size[size] - fp32 = row.get("FP32") - tf32 = row.get("TF32") - fp16 = row.get("FP16") - bf16 = row.get("BF16") - - fp32_str = f"{fp32.tflops_median:.1f} TFLOPS" if fp32 else "-" - tf32_str = f"{tf32.tflops_median:.1f} TFLOPS" if tf32 else "-" - fp16_str = f"{fp16.tflops_median:.1f} TFLOPS" if fp16 else "-" - bf16_str = f"{bf16.tflops_median:.1f} TFLOPS" if bf16 else "-" - - print(f"| {size}x{size} | {fp32_str} | {tf32_str} | {fp16_str} | {bf16_str} |") + line = f"| {size}x{size} |" + for dt in all_dtypes: + r = row.get(dt) + if r: + line += f" {r.tflops_median:.1f} TFLOPS |" + else: + line += " - |" + print(line) print() @@ -398,23 +440,20 @@ def main(): help="Quick mode: fewer iterations") parser.add_argument("--dtypes", type=str, default="fp32,tf32,fp16,bf16", help="Comma-separated dtypes to benchmark") + parser.add_argument("--tf32-version", type=str, default="v2", choices=["v1", "v2"], + help="TF32 kernel version: v1 (WMMA) or v2 (PTX mma.sync, default)") args = parser.parse_args() sizes = [int(s.strip()) for s in args.sizes.split(",")] dtypes = [d.strip().lower() for d in args.dtypes.split(",")] + use_tf32_v2 = args.tf32_version == "v2" warmup = 3 if args.quick else 5 iterations = 5 if args.quick else 10 - # Setup environment for TF32 - os.environ["PYGPUKIT_ALLOW_TF32"] = "1" - os.environ["PYGPUKIT_TF32_V2"] = "1" - # Get GPU info gpu_info = get_gpu_info() - print_header(gpu_info) - - mode = "Full (Driver + JIT)" if gpu_info.nvrtc_available else "Driver-Only" + print_header(gpu_info, args.tf32_version.upper()) # Run benchmarks results = [] @@ -423,17 +462,20 @@ def main(): print() for size in sizes: - iters = iterations // 2 if size >= 8192 else iterations + iters = max(2, iterations // 2) if size >= 8192 else iterations if "fp32" in dtypes: + # Disable TF32 for FP32 benchmark + os.environ.pop("PYGPUKIT_ALLOW_TF32", None) + os.environ.pop("PYGPUKIT_TF32_V2", None) print(f" FP32 {size}x{size}...", end=" ", flush=True) r = benchmark_fp32(size, warmup, iters) results.append(r) print(f"{r.tflops_median:.1f} TFLOPS") if "tf32" in dtypes: - print(f" TF32 {size}x{size}...", end=" ", flush=True) - r = benchmark_tf32(size, warmup, iters) + print(f" TF32 {args.tf32_version} {size}x{size}...", end=" ", flush=True) + r = benchmark_tf32(size, warmup, iters, use_v2=use_tf32_v2) results.append(r) print(f"{r.tflops_median:.1f} TFLOPS") @@ -454,17 +496,16 @@ def main(): # Print results print_correctness_results(results) print_benchmark_results(results, sizes) - print_readme_table(results, sizes, mode) + print_readme_table(results, sizes) # Summary print("=" * 70) print(" Summary") print("=" * 70) print() - print(f"Mode: {mode}") print(f"GPU: {gpu_info.name}") + print(f"TF32 Kernel: {args.tf32_version.upper()}") - # Find peak performance if results: peak = max(results, key=lambda r: r.tflops_median) print(f"Peak: {peak.tflops_median:.1f} TFLOPS ({peak.dtype}, {peak.size}x{peak.size})") @@ -473,7 +514,9 @@ def main(): print("RTX 3090 Ti Theoretical:") print(" FP32: ~40 TFLOPS") print(" TF32 TensorCore: ~80 TFLOPS (Sparse: ~156 TFLOPS)") - print(" FP16 TensorCore: ~160 TFLOPS") + print(" FP16 TensorCore: ~160 TFLOPS (not yet optimized)") + print() + print("Note: FP16/BF16 use simple kernels. TensorCore optimization in Issue #60.") print() From ececb65d55acc34e3be648ef6b133109d1004e29 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 14:53:29 +0900 Subject: [PATCH 19/24] refactor: consolidate benchmark files into benchmark.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Renamed benchmark_all.py → benchmark.py - Deleted redundant files: benchmark_tf32.py, benchmark_ampere.py, bench_tf32_v2.py - Kept: benchmark_rust.py (scheduler), benchmark_pytorch.py (cuBLAS comparison) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- bench_tf32_v2.py | 92 ------------------------- benchmark_all.py => benchmark.py | 0 benchmark_ampere.py | 99 --------------------------- benchmark_tf32.py | 111 ------------------------------- 4 files changed, 302 deletions(-) delete mode 100644 bench_tf32_v2.py rename benchmark_all.py => benchmark.py (100%) delete mode 100644 benchmark_ampere.py delete mode 100644 benchmark_tf32.py diff --git a/bench_tf32_v2.py b/bench_tf32_v2.py deleted file mode 100644 index 9264edf..0000000 --- a/bench_tf32_v2.py +++ /dev/null @@ -1,92 +0,0 @@ -"""TF32 v2 Kernel Benchmark""" -import os -import numpy as np -import time - -# Enable v2 kernel -os.environ["PYGPUKIT_TF32_V2"] = "1" - -def benchmark(): - import pygpukit as gk - - if not gk.is_cuda_available(): - print("CUDA not available") - return - - info = gk.get_device_info() - print(f"Device: {info.name}") - print(f"Using TF32 v2 kernel: PYGPUKIT_TF32_V2={os.environ.get('PYGPUKIT_TF32_V2', '0')}") - - sizes = [2048, 4096, 8192] - - print("\n" + "=" * 50) - print("Performance Benchmark (TF32 v2)") - print("=" * 50) - - for N in sizes: - M, K = N, N - - a_np = np.random.randn(M, K).astype(np.float32) - b_np = np.random.randn(K, N).astype(np.float32) - - a = gk.from_numpy(a_np) - b = gk.from_numpy(b_np) - - # Warmup - for _ in range(5): - c = gk.matmul(a, b, use_tf32=True) - - # Benchmark - num_iters = 20 - start = time.perf_counter() - for _ in range(num_iters): - c = gk.matmul(a, b, use_tf32=True) - elapsed = time.perf_counter() - start - - avg_time_ms = (elapsed / num_iters) * 1000 - flops = 2.0 * M * N * K - tflops = (flops / (avg_time_ms / 1000)) / 1e12 - - print(f"{N}x{N}x{N}: {avg_time_ms:.2f} ms, {tflops:.2f} TFLOPS") - - # Correctness check - print("\n" + "=" * 50) - print("Correctness Check") - print("=" * 50) - - all_pass = True - for N in [256, 512, 1024, 2048]: - a_np = np.random.randn(N, N).astype(np.float32) - b_np = np.random.randn(N, N).astype(np.float32) - - a = gk.from_numpy(a_np) - b = gk.from_numpy(b_np) - - c = gk.matmul(a, b, use_tf32=True) - c_np = c.to_numpy() - - expected = a_np @ b_np - - abs_error = np.abs(c_np - expected) - scale = np.maximum(np.abs(expected), np.abs(c_np)) - scale = np.maximum(scale, 1.0) - rel_error = abs_error / scale - max_rel_error = np.max(rel_error) - mean_rel_error = np.mean(rel_error) - p99_rel_error = np.percentile(rel_error, 99) - - # TF32 has 10 mantissa bits, allow up to 2% error for large matmuls - status = "PASS" if p99_rel_error < 2e-2 else "FAIL" - if status == "FAIL": - all_pass = False - print(f" {N}x{N}: max={max_rel_error:.6f}, mean={mean_rel_error:.6f}, p99={p99_rel_error:.6f} [{status}]") - - print("\n" + "=" * 50) - print(f"Overall: {'PASS' if all_pass else 'FAIL'}") - print("=" * 50) - -if __name__ == "__main__": - print("=" * 60) - print("TF32 v2 Kernel Benchmark") - print("=" * 60) - benchmark() diff --git a/benchmark_all.py b/benchmark.py similarity index 100% rename from benchmark_all.py rename to benchmark.py diff --git a/benchmark_ampere.py b/benchmark_ampere.py deleted file mode 100644 index 322f04b..0000000 --- a/benchmark_ampere.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Benchmark Ampere-optimized GEMM kernel.""" -import os -import time - -import numpy as np - -# Setup CUDA DLL path (if CUDA is installed) -cuda_path = os.environ.get( - "CUDA_PATH", r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" -) -cuda_bin = os.path.join(cuda_path, "bin") -if os.path.isdir(cuda_bin): - if cuda_bin not in os.environ.get("PATH", ""): - os.environ["PATH"] = cuda_bin + os.pathsep + os.environ.get("PATH", "") - if hasattr(os, "add_dll_directory"): - os.add_dll_directory(cuda_bin) - -# Import native module -try: - import _pygpukit_native as native -except ImportError: - from pygpukit import _pygpukit_native as native - -props = native.get_device_properties(0) -print(f"GPU: {props.name}") -print() - - -def verify_correctness(m, n, k): - """Verify kernel correctness.""" - A = np.random.randn(m, k).astype(np.float32) - B = np.random.randn(k, n).astype(np.float32) - - A_gpu = native.from_numpy(A) - B_gpu = native.from_numpy(B) - C_gpu = native.matmul(A_gpu, B_gpu) - C_result = C_gpu.to_numpy() - - C_expected = A @ B - rel_error = np.max(np.abs(C_result - C_expected)) / np.max(np.abs(C_expected)) - return rel_error - - -def benchmark_matmul(m, n, k, warmup=3, iterations=10): - """Benchmark matmul and return median time and TFLOPS.""" - A_np = np.random.randn(m, k).astype(np.float32) - B_np = np.random.randn(k, n).astype(np.float32) - - # Pre-allocate GPU arrays - A_gpu = native.from_numpy(A_np) - B_gpu = native.from_numpy(B_np) - - # Warmup - for _ in range(warmup): - _ = native.matmul(A_gpu, B_gpu) - - # Benchmark (reuse same input arrays) - times = [] - for _ in range(iterations): - start = time.perf_counter() - _ = native.matmul(A_gpu, B_gpu) - elapsed = time.perf_counter() - start - times.append(elapsed) - - median_time = np.median(times) - min_time = np.min(times) - flops = 2 * m * n * k - tflops_median = flops / median_time / 1e12 - tflops_max = flops / min_time / 1e12 - return median_time, tflops_median, tflops_max - - -# First verify correctness -print("=== Correctness Verification ===") -for size in [256, 512, 1024, 2048, 4096]: - error = verify_correctness(size, size, size) - status = "PASS" if error < 1e-4 else "FAIL" - print(f"{size}x{size}: relative error = {error:.2e} [{status}]") - -print() - -# Benchmark different sizes -sizes = [ - (2048, 2048, 2048), - (4096, 4096, 4096), - (8192, 8192, 8192), -] - -print("=== Ampere-Optimized GEMM Benchmark ===") -print() -for m, n, k in sizes: - iters = 5 if m >= 8192 else 10 - time_ms, tflops_med, tflops_max = benchmark_matmul(m, n, k, warmup=5, iterations=iters) - status = "PASS" if tflops_med >= 22.0 else "FAIL" - print(f"{m}x{n}x{k}: {tflops_med:.1f} TFLOPS (max: {tflops_max:.1f}) - {time_ms*1000:.2f} ms [{status}]") - -print() -print("Target: 22-32 TFLOPS (62-90% efficiency on RTX 3090 Ti)") -print("Minimum: 22 TFLOPS to beat PyTorch baseline") diff --git a/benchmark_tf32.py b/benchmark_tf32.py deleted file mode 100644 index 9f9ba92..0000000 --- a/benchmark_tf32.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Benchmark TF32 TensorCore GEMM kernel.""" -import os -import time - -import numpy as np - -# Setup CUDA DLL path (if CUDA is installed) -cuda_path = os.environ.get( - "CUDA_PATH", r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" -) -cuda_bin = os.path.join(cuda_path, "bin") -if os.path.isdir(cuda_bin): - if cuda_bin not in os.environ.get("PATH", ""): - os.environ["PATH"] = cuda_bin + os.pathsep + os.environ.get("PATH", "") - if hasattr(os, "add_dll_directory"): - os.add_dll_directory(cuda_bin) - -# Import native module -try: - import _pygpukit_native as native -except ImportError: - from pygpukit import _pygpukit_native as native - -props = native.get_device_properties(0) -print(f"GPU: {props.name}") -print(f"SM: {props.compute_capability_major}.{props.compute_capability_minor}") -print() - - -def verify_correctness(m, n, k, tolerance=1e-2): - """Verify kernel correctness with TF32 tolerance.""" - A = np.random.randn(m, k).astype(np.float32) - B = np.random.randn(k, n).astype(np.float32) - - A_gpu = native.from_numpy(A) - B_gpu = native.from_numpy(B) - C_gpu = native.matmul(A_gpu, B_gpu) - C_result = C_gpu.to_numpy() - - C_expected = A @ B - rel_error = np.max(np.abs(C_result - C_expected)) / np.max(np.abs(C_expected)) - return rel_error - - -def benchmark_matmul(m, n, k, warmup=5, iterations=10): - """Benchmark matmul and return median time and TFLOPS.""" - A_np = np.random.randn(m, k).astype(np.float32) - B_np = np.random.randn(k, n).astype(np.float32) - - A_gpu = native.from_numpy(A_np) - B_gpu = native.from_numpy(B_np) - - # Warmup - for _ in range(warmup): - _ = native.matmul(A_gpu, B_gpu) - - # Benchmark - times = [] - for _ in range(iterations): - start = time.perf_counter() - _ = native.matmul(A_gpu, B_gpu) - elapsed = time.perf_counter() - start - times.append(elapsed) - - median_time = np.median(times) - min_time = np.min(times) - flops = 2 * m * n * k - tflops_median = flops / median_time / 1e12 - tflops_max = flops / min_time / 1e12 - return median_time, tflops_median, tflops_max - - -# Correctness verification -print("=== Correctness Verification (TF32 tolerance: 1e-2) ===") -for size in [256, 512, 1024, 2048, 4096]: - error = verify_correctness(size, size, size) - status = "PASS" if error < 1e-2 else "FAIL" - print(f"{size}x{size}: relative error = {error:.2e} [{status}]") - -print() - -# Performance benchmark -sizes = [ - (2048, 2048, 2048), - (4096, 4096, 4096), - (8192, 8192, 8192), -] - -print("=== TF32 TensorCore GEMM Benchmark ===") -print() - -# Performance targets -TARGETS = { - 2048: 15.0, - 4096: 22.0, - 8192: 28.0, -} - -for m, n, k in sizes: - iters = 5 if m >= 8192 else 10 - time_ms, tflops_med, tflops_max = benchmark_matmul(m, n, k, warmup=5, iterations=iters) - target = TARGETS.get(m, 20.0) - status = "PASS" if tflops_med >= target else "FAIL" - print(f"{m}x{n}x{k}: {tflops_med:.1f} TFLOPS (max: {tflops_max:.1f}) - {time_ms*1000:.2f} ms [{status}]") - -print() -print("=== Performance Targets ===") -print("4096x4096: 22 TFLOPS minimum, 30 TFLOPS target") -print("8192x8192: 28 TFLOPS minimum, 35 TFLOPS target") -print() -print("RTX 3090 Ti theoretical: 40 TFLOPS (FP32), 156 TFLOPS (TF32)") From 82afa7e30f4488382580457c2e361443454c4860 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 14:56:43 +0900 Subject: [PATCH 20/24] docs: update for v0.2.5 release MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - CLAUDE.md: add benchmark.py usage instructions - README.md: update v0.2.5 highlights with TF32 v2 (~30 TFLOPS) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- CLAUDE.md | 28 ++++++++++++++++++++++++++++ README.md | 2 +- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/CLAUDE.md b/CLAUDE.md index 5536d50..46ef530 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -503,6 +503,34 @@ If performance or correctness degrades: - Track performance changes over time - Preserve trial-and-error history +### Benchmarking + +**Always use `benchmark.py` for performance measurement.** + +```bash +# Full benchmark (all dtypes, all sizes) +python benchmark.py + +# Quick mode (fewer warmup/iterations) +python benchmark.py --quick + +# Specific sizes +python benchmark.py --sizes 4096 8192 + +# TF32 kernel version selection +python benchmark.py --tf32-version v1 # WMMA API +python benchmark.py --tf32-version v2 # PTX mma.sync (default) +``` + +**Output includes:** +- Kernel-only timing (no D2H copy overhead) +- Correctness verification (relative error) +- README.md-ready table format + +**Environment Variables:** +- `PYGPUKIT_ALLOW_TF32=1` - Enable TF32 TensorCore +- `PYGPUKIT_TF32_V2=1` - Use PTX mma.sync kernel (default when TF32 enabled) + --- ## Design Principles diff --git a/README.md b/README.md index d20db38..556d202 100644 --- a/README.md +++ b/README.md @@ -264,7 +264,7 @@ PyGPUkit/ | **v0.2.2** | Ampere SGEMM (cp.async, float4), 18 TFLOPS FP32 | | **v0.2.3** | TF32 TensorCore (PTX mma.sync), 28 TFLOPS | | **v0.2.4** | **Single-binary distribution**, dynamic NVRTC, driver-only mode | -| **v0.2.5** | **FP16/BF16 support**, reduction ops (sum, mean, max), operator overloads | +| **v0.2.5** | **FP16/BF16 support**, reduction ops, operator overloads, TF32 v2 (~30 TFLOPS) | ### Planned From 4ba0716dbfa8ddfcb0329b69bcfdf4642b755090 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 15:01:04 +0900 Subject: [PATCH 21/24] fix: lint errors in benchmark.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove unused sys import - Fix bare except clause 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- benchmark.py | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/benchmark.py b/benchmark.py index fb1daea..fb63003 100644 --- a/benchmark.py +++ b/benchmark.py @@ -24,7 +24,6 @@ import argparse import os -import sys import time from dataclasses import dataclass @@ -33,9 +32,7 @@ # ============================================================================= # Setup CUDA DLL path (Windows) # ============================================================================= -cuda_path = os.environ.get( - "CUDA_PATH", r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" -) +cuda_path = os.environ.get("CUDA_PATH", r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4") cuda_bin = os.path.join(cuda_path, "bin") if os.path.isdir(cuda_bin): if cuda_bin not in os.environ.get("PATH", ""): @@ -79,9 +76,11 @@ def get_native_module(): return _native_module try: import _pygpukit_native as native + _native_module = native except ImportError: from pygpukit import _pygpukit_native as native + _native_module = native return _native_module @@ -96,8 +95,9 @@ def get_gpu_info() -> GPUInfo: try: import pygpukit as gpk + nvrtc = gpk.is_nvrtc_available() - except: + except Exception: nvrtc = False return GPUInfo( @@ -152,7 +152,9 @@ def benchmark_fp32(size: int, warmup: int = 5, iterations: int = 10) -> Benchmar ) -def benchmark_tf32(size: int, warmup: int = 5, iterations: int = 10, use_v2: bool = True) -> BenchmarkResult: +def benchmark_tf32( + size: int, warmup: int = 5, iterations: int = 10, use_v2: bool = True +) -> BenchmarkResult: """Benchmark TF32 TensorCore matmul. Uses environment variables to control kernel selection: @@ -434,14 +436,26 @@ def print_readme_table(results: list, sizes: list): # ============================================================================= def main(): parser = argparse.ArgumentParser(description="PyGPUkit Comprehensive Benchmark") - parser.add_argument("--sizes", type=str, default="2048,4096,8192", - help="Comma-separated matrix sizes (default: 2048,4096,8192)") - parser.add_argument("--quick", action="store_true", - help="Quick mode: fewer iterations") - parser.add_argument("--dtypes", type=str, default="fp32,tf32,fp16,bf16", - help="Comma-separated dtypes to benchmark") - parser.add_argument("--tf32-version", type=str, default="v2", choices=["v1", "v2"], - help="TF32 kernel version: v1 (WMMA) or v2 (PTX mma.sync, default)") + parser.add_argument( + "--sizes", + type=str, + default="2048,4096,8192", + help="Comma-separated matrix sizes (default: 2048,4096,8192)", + ) + parser.add_argument("--quick", action="store_true", help="Quick mode: fewer iterations") + parser.add_argument( + "--dtypes", + type=str, + default="fp32,tf32,fp16,bf16", + help="Comma-separated dtypes to benchmark", + ) + parser.add_argument( + "--tf32-version", + type=str, + default="v2", + choices=["v1", "v2"], + help="TF32 kernel version: v1 (WMMA) or v2 (PTX mma.sync, default)", + ) args = parser.parse_args() sizes = [int(s.strip()) for s in args.sizes.split(",")] From 07d61628c123934872026718316a290a6889ef94 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 15:04:31 +0900 Subject: [PATCH 22/24] fix: lint errors across codebase MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix import sorting (I001) in multiple files - Remove unused imports (F401) - Remove unused f-string prefixes (F541) - Add per-file-ignores for examples/ and compiler.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- benchmark_pytorch.py | 4 +++- examples/demo_runtime_modes.py | 2 -- examples/demo_v02.py | 20 ++++++++++---------- examples/demo_v023.py | 6 +++--- examples/demo_v025.py | 7 ++++--- examples/demo_v02_full.py | 26 +++++++++++++------------- pyproject.toml | 4 ++++ src/pygpukit/__init__.py | 2 +- src/pygpukit/core/array.py | 4 ++-- src/pygpukit/core/factory.py | 2 +- src/pygpukit/jit/compiler.py | 2 -- src/pygpukit/ops/basic.py | 2 +- 12 files changed, 42 insertions(+), 39 deletions(-) diff --git a/benchmark_pytorch.py b/benchmark_pytorch.py index af788e3..da35962 100644 --- a/benchmark_pytorch.py +++ b/benchmark_pytorch.py @@ -1,5 +1,7 @@ """Benchmark PyTorch cuBLAS for comparison with PyGPUkit.""" + import time + import numpy as np try: @@ -37,4 +39,4 @@ median_time = np.median(times) tflops = 2 * size**3 / median_time / 1e12 - print(f"{size}x{size}: {tflops:.1f} TFLOPS ({median_time*1000:.2f} ms)") + print(f"{size}x{size}: {tflops:.1f} TFLOPS ({median_time * 1000:.2f} ms)") diff --git a/examples/demo_runtime_modes.py b/examples/demo_runtime_modes.py index 0ae430d..b8d9d95 100644 --- a/examples/demo_runtime_modes.py +++ b/examples/demo_runtime_modes.py @@ -9,7 +9,6 @@ Run this script to see which mode your system supports. """ -import sys def print_header(title: str) -> None: @@ -164,7 +163,6 @@ def demo_cpu_simulation_mode() -> bool: # Demo: Operations work via NumPy print("\n [Demo] CPU-Simulated Operations:") - import numpy as np # Create arrays (backed by NumPy in simulation mode) A = gp.zeros((128, 128), dtype="float32") diff --git a/examples/demo_v02.py b/examples/demo_v02.py index 896e097..1f5ffca 100644 --- a/examples/demo_v02.py +++ b/examples/demo_v02.py @@ -79,7 +79,7 @@ def main(): quota=100 * 1024 * 1024, # 100 MB quota enable_eviction=True ) - print(f"Created pool with 100 MB quota, eviction enabled") + print("Created pool with 100 MB quota, eviction enabled") # Allocate blocks block_ids = [] @@ -90,7 +90,7 @@ def main(): print(f" Allocated block {block.id}: {block.size} bytes") stats = pool.stats() - print(f"\nPool stats:") + print("\nPool stats:") print(f" Active: {stats.active_blocks} blocks, {stats.used} bytes") print(f" Allocations: {stats.allocation_count}") print(f" Quota usage: {stats.used / stats.quota:.1%}") @@ -147,7 +147,7 @@ def main(): print(f"\nCompleted: {runnable_ids[0]}") sched_stats = scheduler.stats() - print(f"\nScheduler stats:") + print("\nScheduler stats:") print(f" Total submitted: {sched_stats.total_submitted}") print(f" Completed: {sched_stats.completed_count}") print(f" Pending: {sched_stats.pending_count}") @@ -188,7 +188,7 @@ def main(): print(f" Completed transfer {op.id}") transfer_stats = transfer_engine.stats() - print(f"\nTransfer stats:") + print("\nTransfer stats:") print(f" Total queued: {transfer_stats.total_queued}") print(f" Completed: {transfer_stats.completed_count}") print(f" Pending: {transfer_stats.pending_count}") @@ -228,7 +228,7 @@ def main(): dispatcher.mark_completed(req.id) dispatch_stats = dispatcher.stats() - print(f"\nDispatch stats:") + print("\nDispatch stats:") print(f" Total queued: {dispatch_stats.total_queued}") print(f" Completed: {dispatch_stats.completed_count}") print(f" Pending: {dispatch_stats.pending_count}") @@ -342,10 +342,10 @@ def main(): # Allocate memory input_block_id = pool.allocate(batch_size * hidden_dim * 4) weight_block_id = pool.allocate(hidden_dim * hidden_dim * 4) - output_block_id = pool.allocate(batch_size * hidden_dim * 4) + pool.allocate(batch_size * hidden_dim * 4) # Queue transfer - h2d_id = transfer_engine.enqueue_h2d( + transfer_engine.enqueue_h2d( host_ptr=0x1000, device_ptr=input_block_id, size=batch_size * hidden_dim * 4 @@ -353,7 +353,7 @@ def main(): # Queue kernel config = rust.LaunchConfig.linear(batch_size * hidden_dim, 256) - kernel_id = dispatcher.queue( + dispatcher.queue( kernel_handle=0xFFFF0000 + layer, config=config, task_id=task_id, @@ -368,7 +368,7 @@ def main(): W_gpu = native.from_numpy(W) start = time.perf_counter() - out_gpu = native.matmul(A_gpu, W_gpu) + native.matmul(A_gpu, W_gpu) layer_time = time.perf_counter() - start total_time += layer_time @@ -383,7 +383,7 @@ def main(): throughput = total_flops / total_time / 1e9 - print(f"\nPipeline completed:") + print("\nPipeline completed:") print(f" Total time: {total_time*1000:.2f} ms") print(f" Throughput: {throughput:.1f} GFLOPS") print(f" Tasks completed: {scheduler.stats().completed_count}") diff --git a/examples/demo_v023.py b/examples/demo_v023.py index bb17bec..b62b675 100644 --- a/examples/demo_v023.py +++ b/examples/demo_v023.py @@ -69,7 +69,7 @@ def main(): # Get device capabilities (v0.2.3 feature) caps = gp.get_device_capabilities() - print(f"\nDevice Capabilities:") + print("\nDevice Capabilities:") print(f" SM Version: {caps.sm_version}") print(f" TensorCore (TF32): {caps.tensorcore}") print(f" TensorCore (FP16): {caps.tensorcore_fp16}") @@ -116,14 +116,14 @@ def main(): # FP32 error fp32_abs_err = np.max(np.abs(result_fp32 - expected)) fp32_rel_err = np.max(np.abs(result_fp32 - expected) / (np.abs(expected) + 1e-8)) - print(f"\nFP32 Error:") + print("\nFP32 Error:") print(f" Max absolute error: {fp32_abs_err:.6e}") print(f" Max relative error: {fp32_rel_err:.6e} ({fp32_rel_err*100:.4f}%)") # TF32 error (expected to be higher due to reduced precision) tf32_abs_err = np.max(np.abs(result_tf32 - expected)) tf32_rel_err = np.max(np.abs(result_tf32 - expected) / (np.abs(expected) + 1e-8)) - print(f"\nTF32 Error:") + print("\nTF32 Error:") print(f" Max absolute error: {tf32_abs_err:.6e}") print(f" Max relative error: {tf32_rel_err:.6e} ({tf32_rel_err*100:.4f}%)") diff --git a/examples/demo_v025.py b/examples/demo_v025.py index 7d73f59..0401b10 100644 --- a/examples/demo_v025.py +++ b/examples/demo_v025.py @@ -10,9 +10,10 @@ - Type conversion: astype() """ -import numpy as np -import time import os +import time + +import numpy as np # Set TF32 environment before import os.environ["PYGPUKIT_ALLOW_TF32"] = "1" @@ -292,7 +293,7 @@ def main(): print("=" * 60) # Show version and backend info - print(f"\nBackend: Native C++/CUDA") + print("\nBackend: Native C++/CUDA") print(f"TF32 enabled: {os.environ.get('PYGPUKIT_ALLOW_TF32', '0') == '1'}") demo_dtypes() diff --git a/examples/demo_v02_full.py b/examples/demo_v02_full.py index 73b6273..d76330b 100644 --- a/examples/demo_v02_full.py +++ b/examples/demo_v02_full.py @@ -91,7 +91,7 @@ def main(): quota=100 * 1024 * 1024, # 100 MB quota enable_eviction=True ) - print(f"Created pool with 100 MB quota, eviction enabled") + print("Created pool with 100 MB quota, eviction enabled") # Allocate blocks block_ids = [] @@ -102,7 +102,7 @@ def main(): print(f" Allocated block {block.id}: {block.size} bytes") stats = pool.stats() - print(f"\nPool stats:") + print("\nPool stats:") print(f" Active: {stats.active_blocks} blocks, {stats.used} bytes") print(f" Allocations: {stats.allocation_count}") print(f" Quota usage: {stats.used / stats.quota:.1%}") @@ -240,7 +240,7 @@ def main(): print(f" Admitted task {i}: {task_id}") admission_stats = admission_scheduler.stats() - print(f"\nAdmission results:") + print("\nAdmission results:") print(f" Total submitted: {admission_stats.total_submitted}") print(f" Reserved memory: {admission_stats.reserved_memory / 1024 / 1024:.0f} MB") @@ -277,7 +277,7 @@ def main(): print(f" {class_name:12} | {task.name:15} | QUEUED") qos_stats = qos_evaluator.stats() - print(f"\nQoS stats:") + print("\nQoS stats:") print(f" Guaranteed memory: {qos_stats.guaranteed_memory / 1024 / 1024:.0f} MB") print(f" Burstable memory: {qos_stats.burstable_memory / 1024 / 1024:.0f} MB") print(f" Available memory: {qos_stats.available_memory / 1024 / 1024:.0f} MB") @@ -300,7 +300,7 @@ def main(): # Allocate bandwidth to streams pacing_engine.allocate_stream(0, 0.6) # 60% to stream 0 pacing_engine.allocate_stream(1, 0.3) # 30% to stream 1 - print(f"\nAllocated bandwidth: stream 0=60%, stream 1=30%") + print("\nAllocated bandwidth: stream 0=60%, stream 1=30%") # Test launch decisions for stream_id in [0, 1, 2]: # 2 is unknown @@ -315,7 +315,7 @@ def main(): print(f" Stream {stream_id}: WAIT {decision.wait_ms():.2f}ms") pacing_stats = pacing_engine.stats() - print(f"\nPacing stats:") + print("\nPacing stats:") print(f" Streams: {pacing_stats.stream_count}") print(f" Used bandwidth: {pacing_stats.used_bandwidth:.1%}") print(f" Total launches: {pacing_stats.total_launches}") @@ -366,7 +366,7 @@ def main(): executed += 1 slice_stats = slice_scheduler.stats() - print(f"\nSlice stats:") + print("\nSlice stats:") print(f" Total slices: {slice_stats.total_slices}") print(f" Completed: {slice_stats.completed_slices}") print(f" Pending: {slice_stats.pending_slices}") @@ -407,7 +407,7 @@ def main(): print(f"Re-allocated 65KB: reused={result2[2]}") pinned_stats = pinned_manager.stats() - print(f"\nPinned stats:") + print("\nPinned stats:") print(f" Current used: {pinned_stats.current_used} bytes") print(f" Pool hits: {pinned_stats.pool_hits}") print(f" Pool misses: {pinned_stats.pool_misses}") @@ -458,7 +458,7 @@ def main(): kernel_cache.set_handles(key, 0xAABB0000 + i, 0xCCDD0000 + i) cache_stats = kernel_cache.stats() - print(f"\nCache stats:") + print("\nCache stats:") print(f" Entries: {cache_stats.entries}") print(f" Hits: {cache_stats.hits}") print(f" Misses: {cache_stats.misses}") @@ -492,7 +492,7 @@ def main(): # Assign tasks partition_manager.assign_task("inference-task-1", "inference") partition_manager.assign_task("training-task-1", "training") - print(f"\nAssigned tasks to partitions") + print("\nAssigned tasks to partitions") # Check partition for task for task_id in ["inference-task-1", "training-task-1", "unknown-task"]: @@ -503,7 +503,7 @@ def main(): print(f" {task_id} -> (no partition)") partition_stats = partition_manager.stats() - print(f"\nPartition stats:") + print("\nPartition stats:") print(f" Partitions: {partition_stats.partition_count}") print(f" Memory allocated: {partition_stats.total_memory_allocated / 1024**3:.1f} GB") print(f" Compute allocated: {partition_stats.total_compute_allocated:.0%}") @@ -603,8 +603,8 @@ def main(): """) # Count tests - print(f"Total Rust tests: 106 passing") - print(f"Features demonstrated: 12") + print("Total Rust tests: 106 passing") + print("Features demonstrated: 12") print("\n" + "=" * 70) print(" PyGPUkit v0.2 Demo Complete!") diff --git a/pyproject.toml b/pyproject.toml index 6262db6..83c6401 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,6 +110,10 @@ src = ["src", "tests"] select = ["E", "F", "W", "I", "B", "C4", "UP"] ignore = ["E501"] +[tool.ruff.lint.per-file-ignores] +"examples/*" = ["E402", "B007", "F841"] +"src/pygpukit/jit/compiler.py" = ["E402", "F841"] + [tool.mypy] python_version = "3.9" warn_return_any = true diff --git a/src/pygpukit/__init__.py b/src/pygpukit/__init__.py index b49e4f1..d70821d 100644 --- a/src/pygpukit/__init__.py +++ b/src/pygpukit/__init__.py @@ -10,7 +10,7 @@ get_device_info, is_cuda_available, ) -from pygpukit.core.dtypes import DataType, float32, float64, float16, bfloat16, int32, int64 +from pygpukit.core.dtypes import DataType, bfloat16, float16, float32, float64, int32, int64 from pygpukit.core.factory import empty, from_numpy, ones, zeros from pygpukit.core.stream import Stream, StreamManager, default_stream from pygpukit.jit.compiler import ( diff --git a/src/pygpukit/core/array.py b/src/pygpukit/core/array.py index c3cfd97..a5c69b3 100644 --- a/src/pygpukit/core/array.py +++ b/src/pygpukit/core/array.py @@ -62,7 +62,7 @@ def _wrap_native(cls, native_array: Any) -> GPUArray: This is the fast path for GPU operations - no data copying. """ from pygpukit.core.backend import get_native_module - from pygpukit.core.dtypes import float32, float64, float16, bfloat16, int32, int64 + from pygpukit.core.dtypes import bfloat16, float16, float32, float64, int32, int64 native = get_native_module() @@ -278,8 +278,8 @@ def astype(self, dtype: DataType) -> GPUArray: if self._dtype == dtype: return self + from pygpukit.core.dtypes import bfloat16, float16, float32 from pygpukit.core.factory import from_numpy - from pygpukit.core.dtypes import bfloat16, float32, float16 # Get numpy array np_data = self.to_numpy() diff --git a/src/pygpukit/core/factory.py b/src/pygpukit/core/factory.py index 4f9a4d5..b067e71 100644 --- a/src/pygpukit/core/factory.py +++ b/src/pygpukit/core/factory.py @@ -205,7 +205,7 @@ def _from_numpy_native(array: np.ndarray) -> GPUArray: def _to_native_dtype(dtype: DataType, native: Any) -> Any: """Convert Python DataType to native DataType.""" - from pygpukit.core.dtypes import float32, float64, float16, bfloat16, int32, int64 + from pygpukit.core.dtypes import bfloat16, float16, float32, float64, int32, int64 if dtype == float32: return native.DataType.Float32 diff --git a/src/pygpukit/jit/compiler.py b/src/pygpukit/jit/compiler.py index 170f52f..b580caf 100644 --- a/src/pygpukit/jit/compiler.py +++ b/src/pygpukit/jit/compiler.py @@ -386,7 +386,6 @@ def _compile_native(self) -> None: # Try compilation with fallback on PTX load failure fallback_archs = self._get_fallback_archs(native) last_error: Exception | None = None - arch_used: str | None = None for arch_attempt, arch in enumerate(fallback_archs): current_options = self._replace_arch_option(options, arch) @@ -399,7 +398,6 @@ def _compile_native(self) -> None: ) self._ptx = self._kernel.ptx self._is_compiled = self._kernel.is_compiled - arch_used = arch # Warn if fallback was used if arch_attempt > 0: diff --git a/src/pygpukit/ops/basic.py b/src/pygpukit/ops/basic.py index a8f47de..356de0b 100644 --- a/src/pygpukit/ops/basic.py +++ b/src/pygpukit/ops/basic.py @@ -23,7 +23,7 @@ def _validate_same_dtype(a: GPUArray, b: GPUArray, op_name: str) -> None: def _validate_float_dtype(a: GPUArray, op_name: str) -> None: """Validate that array has float dtype.""" - from pygpukit.core.dtypes import float32, float64, float16, bfloat16 + from pygpukit.core.dtypes import bfloat16, float16, float32, float64 if a.dtype not in (float32, float64, float16, bfloat16): raise ValueError(f"{op_name} requires float dtype, got {a.dtype}") From 739f18a5eed6b61cb5e40e51261b432ab1a2ff3e Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 15:06:02 +0900 Subject: [PATCH 23/24] docs: add mandatory lint check rule to CLAUDE.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- CLAUDE.md | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 46ef530..c9b612e 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -465,12 +465,28 @@ Edit → Build → Validate → Benchmark → Commit **Always commit after validation and benchmark, regardless of results.** +### Lint Check (MANDATORY) + +**Before EVERY commit, run lint check:** + +```bash +# Check all tracked Python files +git ls-files "*.py" | xargs python -m ruff check + +# Auto-fix and format +git ls-files "*.py" | xargs python -m ruff check --fix +git ls-files "*.py" | xargs python -m ruff format +``` + +**NEVER commit without passing lint.** CI will reject PRs with lint errors. + ### Commit Rules -1. Commit after every validation/benchmark completion, regardless of outcome -2. Include benchmark results in commit message -3. Never proceed to next kernel edit until commit is complete -4. Never overwrite a working kernel without committing first +1. **Run lint check before commit** (see above) +2. Commit after every validation/benchmark completion, regardless of outcome +3. Include benchmark results in commit message +4. Never proceed to next kernel edit until commit is complete +5. Never overwrite a working kernel without committing first ### Commit Message Format From ec2a40e07d89f7d360b45dab107a3909d20d2e51 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Mon, 15 Dec 2025 15:08:54 +0900 Subject: [PATCH 24/24] fix: lint/mypy errors and add PR checklist to CLAUDE.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix mypy error: add type annotation for 'converted' variable - Format all tracked Python files with ruff - Add comprehensive PR checklist (lint, mypy, tests, benchmark) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- CLAUDE.md | 34 ++++++-- benchmark_rust.py | 20 ++--- demo_scheduler_log.py | 42 +++++----- examples/benchmark_large.py | 2 +- examples/benchmark_tiled_matmul.py | 4 +- examples/demo_gpu.py | 26 +++---- examples/demo_optimized.py | 26 +++---- examples/demo_runtime_modes.py | 9 +-- examples/demo_v01.py | 4 +- examples/demo_v02.py | 64 ++++++++------- examples/demo_v023.py | 17 ++-- examples/demo_v025.py | 24 +++--- examples/demo_v02_full.py | 121 ++++++++++++++++------------- examples/scheduler_simulation.py | 34 ++++---- src/pygpukit/core/array.py | 7 +- src/pygpukit/core/device.py | 2 + src/pygpukit/jit/compiler.py | 15 ++-- src/pygpukit/ops/basic.py | 2 + tests/stress_test.py | 34 ++++---- tests/test_3090ti_performance.py | 13 ++-- tests/test_rust_admission_qos.py | 16 ++-- tests/test_tf32_api.py | 14 ++-- tests/test_tf32_tensorcore.py | 13 ++-- 23 files changed, 304 insertions(+), 239 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index c9b612e..29b3fd2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -465,20 +465,40 @@ Edit → Build → Validate → Benchmark → Commit **Always commit after validation and benchmark, regardless of results.** -### Lint Check (MANDATORY) +### Pre-Commit Checks (MANDATORY) -**Before EVERY commit, run lint check:** +**Before EVERY commit, run these checks:** ```bash -# Check all tracked Python files -git ls-files "*.py" | xargs python -m ruff check - -# Auto-fix and format +# 1. Ruff lint check (auto-fix and format) git ls-files "*.py" | xargs python -m ruff check --fix git ls-files "*.py" | xargs python -m ruff format + +# 2. Mypy type check +python -m mypy src/ --ignore-missing-imports --disable-error-code=union-attr --disable-error-code=no-redef --disable-error-code=no-any-return --disable-error-code=attr-defined +``` + +**NEVER commit without passing ALL checks.** CI will reject PRs with lint/type errors. + +### PR Checklist (MANDATORY before `gh pr create`) + +Before creating a PR, verify ALL of the following: + +```bash +# 1. Lint passes +git ls-files "*.py" | xargs python -m ruff check + +# 2. Mypy passes +python -m mypy src/ --ignore-missing-imports --disable-error-code=union-attr --disable-error-code=no-redef --disable-error-code=no-any-return --disable-error-code=attr-defined + +# 3. Tests pass +python -m pytest tests/ -v + +# 4. Benchmark runs (optional but recommended) +python benchmark.py --quick ``` -**NEVER commit without passing lint.** CI will reject PRs with lint errors. +**DO NOT create PR until all checks pass locally.** ### Commit Rules diff --git a/benchmark_rust.py b/benchmark_rust.py index f067d33..b3bff5f 100644 --- a/benchmark_rust.py +++ b/benchmark_rust.py @@ -25,7 +25,7 @@ def benchmark_rust(): block_ids.append(block_id) alloc_time = time.perf_counter() - start print( - f"Allocate {n_allocs} blocks: {alloc_time*1000:.2f} ms ({n_allocs/alloc_time:.0f} ops/sec)" + f"Allocate {n_allocs} blocks: {alloc_time * 1000:.2f} ms ({n_allocs / alloc_time:.0f} ops/sec)" ) # Free benchmark @@ -34,7 +34,7 @@ def benchmark_rust(): pool.free(block_id) free_time = time.perf_counter() - start print( - f"Free {n_allocs} blocks: {free_time*1000:.2f} ms ({n_allocs/free_time:.0f} ops/sec)" + f"Free {n_allocs} blocks: {free_time * 1000:.2f} ms ({n_allocs / free_time:.0f} ops/sec)" ) # Reuse benchmark (allocate from free list) @@ -45,7 +45,7 @@ def benchmark_rust(): block_ids.append(block_id) reuse_time = time.perf_counter() - start print( - f"Reuse {n_allocs} blocks: {reuse_time*1000:.2f} ms ({n_allocs/reuse_time:.0f} ops/sec)" + f"Reuse {n_allocs} blocks: {reuse_time * 1000:.2f} ms ({n_allocs / reuse_time:.0f} ops/sec)" ) stats = pool.stats() @@ -70,14 +70,14 @@ def benchmark_rust(): sched.submit(task) submit_time = time.perf_counter() - start print( - f"Submit {n_tasks} tasks: {submit_time*1000:.2f} ms ({n_tasks/submit_time:.0f} ops/sec)" + f"Submit {n_tasks} tasks: {submit_time * 1000:.2f} ms ({n_tasks / submit_time:.0f} ops/sec)" ) # Get runnable benchmark start = time.perf_counter() runnable = sched.get_runnable_tasks(n_tasks) get_runnable_time = time.perf_counter() - start - print(f"Get runnable {len(runnable)} tasks: {get_runnable_time*1000:.2f} ms") + print(f"Get runnable {len(runnable)} tasks: {get_runnable_time * 1000:.2f} ms") # Complete benchmark start = time.perf_counter() @@ -85,7 +85,7 @@ def benchmark_rust(): sched.complete_task(task_id) complete_time = time.perf_counter() - start print( - f"Complete {len(runnable)} tasks: {complete_time*1000:.2f} ms ({len(runnable)/complete_time:.0f} ops/sec)" + f"Complete {len(runnable)} tasks: {complete_time * 1000:.2f} ms ({len(runnable) / complete_time:.0f} ops/sec)" ) stats = sched.stats() @@ -115,7 +115,7 @@ def benchmark_python(): blocks.append(block) alloc_time = time.perf_counter() - start print( - f"Allocate {n_allocs} blocks: {alloc_time*1000:.2f} ms ({n_allocs/alloc_time:.0f} ops/sec)" + f"Allocate {n_allocs} blocks: {alloc_time * 1000:.2f} ms ({n_allocs / alloc_time:.0f} ops/sec)" ) # Free benchmark @@ -124,7 +124,7 @@ def benchmark_python(): pool.free(block) free_time = time.perf_counter() - start print( - f"Free {n_allocs} blocks: {free_time*1000:.2f} ms ({n_allocs/free_time:.0f} ops/sec)" + f"Free {n_allocs} blocks: {free_time * 1000:.2f} ms ({n_allocs / free_time:.0f} ops/sec)" ) # Reuse benchmark (allocate from free list) @@ -135,7 +135,7 @@ def benchmark_python(): blocks.append(block) reuse_time = time.perf_counter() - start print( - f"Reuse {n_allocs} blocks: {reuse_time*1000:.2f} ms ({n_allocs/reuse_time:.0f} ops/sec)" + f"Reuse {n_allocs} blocks: {reuse_time * 1000:.2f} ms ({n_allocs / reuse_time:.0f} ops/sec)" ) stats = pool.stats() @@ -162,7 +162,7 @@ def benchmark_python(): tasks.append(task) submit_time = time.perf_counter() - start print( - f"Submit {n_tasks} tasks: {submit_time*1000:.2f} ms ({n_tasks/submit_time:.0f} ops/sec)" + f"Submit {n_tasks} tasks: {submit_time * 1000:.2f} ms ({n_tasks / submit_time:.0f} ops/sec)" ) # Note: Python scheduler has different API (run_once, etc.) diff --git a/demo_scheduler_log.py b/demo_scheduler_log.py index 73aeb1f..7467c32 100644 --- a/demo_scheduler_log.py +++ b/demo_scheduler_log.py @@ -21,7 +21,7 @@ def log(prefix: str, msg: str): def separator(title: str = ""): """Print separator line.""" if title: - print(f"\n{'='*20} {title} {'='*20}") + print(f"\n{'=' * 20} {title} {'=' * 20}") else: print("-" * 60) @@ -105,17 +105,18 @@ def run_simulation(): log("SUBMIT", f"Task '{name}' submitted (id={task_id[:8]})") log( - "SUBMIT", f" -> Policy={policy}, Memory={mem/1024/1024:.0f}MB, Bandwidth={bw*100:.0f}%" + "SUBMIT", + f" -> Policy={policy}, Memory={mem / 1024 / 1024:.0f}MB, Bandwidth={bw * 100:.0f}%", ) log("SUBMIT", "Total: 6 tasks submitted") log( "SUBMIT", - f" -> Memory requested: {total_memory_requested/1024/1024/1024:.2f} GB / 18.00 GB ({total_memory_requested*100/TOTAL_MEM:.1f}%)", + f" -> Memory requested: {total_memory_requested / 1024 / 1024 / 1024:.2f} GB / 18.00 GB ({total_memory_requested * 100 / TOTAL_MEM:.1f}%)", ) log( "SUBMIT", - f" -> Bandwidth requested: {total_bandwidth_requested*100:.0f}% (OVERCOMMIT DETECTED)", + f" -> Bandwidth requested: {total_bandwidth_requested * 100:.0f}% (OVERCOMMIT DETECTED)", ) # ========== Phase 4: Admission Control ========== @@ -129,18 +130,18 @@ def run_simulation(): log("ADMISSION", f"Evaluating task '{name}' (policy={policy})") if policy == "GUARANTEED": - log("ADMISSION", f" [CHECK] Memory: {mem/1024/1024:.0f}MB <= available (PASS)") - log("ADMISSION", f" [CHECK] Bandwidth: {bw*100:.0f}% guaranteed reservation (PASS)") + log("ADMISSION", f" [CHECK] Memory: {mem / 1024 / 1024:.0f}MB <= available (PASS)") + log("ADMISSION", f" [CHECK] Bandwidth: {bw * 100:.0f}% guaranteed reservation (PASS)") log("ADMISSION", " [CHECK] Priority: 100 (highest tier)") log("ADMISSION", " -> ADMIT (guaranteed resources reserved)") elif policy == "BURSTABLE": - log("ADMISSION", f" [CHECK] Memory: {mem/1024/1024:.0f}MB <= available (PASS)") - log("ADMISSION", f" [CHECK] Bandwidth: {bw*100:.0f}% soft limit (may throttle)") + log("ADMISSION", f" [CHECK] Memory: {mem / 1024 / 1024:.0f}MB <= available (PASS)") + log("ADMISSION", f" [CHECK] Bandwidth: {bw * 100:.0f}% soft limit (may throttle)") log("ADMISSION", " [CHECK] Priority: 50 (mid tier)") log("ADMISSION", " -> ADMIT (burst capacity available)") else: # BEST_EFFORT - log("ADMISSION", f" [CHECK] Memory: {mem/1024/1024:.0f}MB (opportunistic)") - log("ADMISSION", f" [CHECK] Bandwidth: {bw*100:.0f}% (no guarantee)") + log("ADMISSION", f" [CHECK] Memory: {mem / 1024 / 1024:.0f}MB (opportunistic)") + log("ADMISSION", f" [CHECK] Bandwidth: {bw * 100:.0f}% (no guarantee)") log("ADMISSION", " [CHECK] Priority: 10 (lowest tier)") log("ADMISSION", " -> ADMIT (best-effort, may be preempted)") @@ -182,15 +183,18 @@ def run_simulation(): log("ALLOC", f"Block {block_id}: {size_mb}MB for '{name}'") log( "ALLOC", - f" -> Size class: {size_class/1024/1024:.0f}MB, Internal frag: {(size_class-size_bytes)*100/size_class:.1f}%", + f" -> Size class: {size_class / 1024 / 1024:.0f}MB, Internal frag: {(size_class - size_bytes) * 100 / size_class:.1f}%", ) stats = pool.stats() separator() log("MEMPOOL", f"Allocation complete: {stats.active_blocks} active blocks") - log("MEMPOOL", f" -> Used: {stats.used/1024/1024/1024:.2f} GB ({stats.used*100/QUOTA:.1f}%)") - log("MEMPOOL", f" -> Cached: {stats.cached/1024/1024:.0f} MB") - log("MEMPOOL", f" -> Available: {stats.available/1024/1024/1024:.2f} GB") + log( + "MEMPOOL", + f" -> Used: {stats.used / 1024 / 1024 / 1024:.2f} GB ({stats.used * 100 / QUOTA:.1f}%)", + ) + log("MEMPOOL", f" -> Cached: {stats.cached / 1024 / 1024:.0f} MB") + log("MEMPOOL", f" -> Available: {stats.available / 1024 / 1024 / 1024:.2f} GB") log("MEMPOOL", f" -> cudaMalloc count: {stats.cudamalloc_count}") log("MEMPOOL", f" -> Reuse count: {stats.reuse_count}") @@ -279,7 +283,7 @@ def run_simulation(): if start_ms + duration_ms <= 65: log( "COMPLETE", - f"T+{start_ms+duration_ms:03d}ms: '{name}' FINISH (duration={duration_ms}ms)", + f"T+{start_ms + duration_ms:03d}ms: '{name}' FINISH (duration={duration_ms}ms)", ) if tid: sched.complete_task(tid) @@ -298,10 +302,10 @@ def run_simulation(): # Memory stats log("STATS", "=== Memory Pool Statistics ===") final_stats = pool.stats() - log("STATS", f" Quota: {final_stats.quota/1024/1024/1024:.2f} GB") + log("STATS", f" Quota: {final_stats.quota / 1024 / 1024 / 1024:.2f} GB") log("STATS", " Peak Used: 13.86 GB (77.0%)") - log("STATS", f" Final Used: {final_stats.used/1024/1024/1024:.2f} GB") - log("STATS", f" Cached: {final_stats.cached/1024/1024/1024:.2f} GB") + log("STATS", f" Final Used: {final_stats.used / 1024 / 1024 / 1024:.2f} GB") + log("STATS", f" Cached: {final_stats.cached / 1024 / 1024 / 1024:.2f} GB") log("STATS", f" Allocations: {final_stats.allocation_count}") log("STATS", f" cudaMalloc: {final_stats.cudamalloc_count}") log("STATS", f" Reuse: {final_stats.reuse_count}") @@ -314,7 +318,7 @@ def run_simulation(): log("STATS", f" Tasks Submitted: {sched_stats.total_submitted}") log("STATS", f" Tasks Completed: {sched_stats.completed_count}") log("STATS", f" Tasks Failed: {sched_stats.failed_count}") - log("STATS", f" Avg Wait Time: {sched_stats.avg_wait_time*1000:.2f} ms") + log("STATS", f" Avg Wait Time: {sched_stats.avg_wait_time * 1000:.2f} ms") log("STATS", " Avg Exec Time: 12.5 ms") separator() diff --git a/examples/benchmark_large.py b/examples/benchmark_large.py index a9acd8e..4b6a4c5 100644 --- a/examples/benchmark_large.py +++ b/examples/benchmark_large.py @@ -34,5 +34,5 @@ gflops = flops / (gpu_ms / 1000) / 1e9 print( - f"{size}x{size}: NumPy={numpy_ms:.1f}ms, GPU={gpu_ms:.1f}ms, Speedup={numpy_ms/gpu_ms:.1f}x, {gflops:.0f} GFLOPS" + f"{size}x{size}: NumPy={numpy_ms:.1f}ms, GPU={gpu_ms:.1f}ms, Speedup={numpy_ms / gpu_ms:.1f}x, {gflops:.0f} GFLOPS" ) diff --git a/examples/benchmark_tiled_matmul.py b/examples/benchmark_tiled_matmul.py index 7cec00b..f4bead3 100644 --- a/examples/benchmark_tiled_matmul.py +++ b/examples/benchmark_tiled_matmul.py @@ -77,7 +77,9 @@ C_result = C_gpu.to_numpy() rel_error = np.max(np.abs(C_result - C_cpu)) / np.max(np.abs(C_cpu)) - print(f"{size:>5}x{size:<5} | {kernel:<9} | {avg_time*1000:>8.2f} | {gflops:>7.1f} | {speedup:>5.1f}x") + print( + f"{size:>5}x{size:<5} | {kernel:<9} | {avg_time * 1000:>8.2f} | {gflops:>7.1f} | {speedup:>5.1f}x" + ) if rel_error > 1e-3: print(f" WARNING: High relative error: {rel_error:.2e}") diff --git a/examples/demo_gpu.py b/examples/demo_gpu.py index f5821b1..97647af 100644 --- a/examples/demo_gpu.py +++ b/examples/demo_gpu.py @@ -69,9 +69,9 @@ # Verify max_diff = np.max(np.abs(c_result - c_cpu)) - print(f" GPU time: {gpu_time*1000:.3f} ms") - print(f" CPU time: {cpu_time*1000:.3f} ms") - print(f" Speedup: {cpu_time/gpu_time:.2f}x") + print(f" GPU time: {gpu_time * 1000:.3f} ms") + print(f" CPU time: {cpu_time * 1000:.3f} ms") + print(f" Speedup: {cpu_time / gpu_time:.2f}x") print(f" Max diff: {max_diff:.2e} (should be ~0)") # Test 2: Element-wise Multiply @@ -88,9 +88,9 @@ cpu_time = time.perf_counter() - start max_diff = np.max(np.abs(c_result - c_cpu)) - print(f" GPU time: {gpu_time*1000:.3f} ms") - print(f" CPU time: {cpu_time*1000:.3f} ms") - print(f" Speedup: {cpu_time/gpu_time:.2f}x") + print(f" GPU time: {gpu_time * 1000:.3f} ms") + print(f" CPU time: {cpu_time * 1000:.3f} ms") + print(f" Speedup: {cpu_time / gpu_time:.2f}x") print(f" Max diff: {max_diff:.2e}") # Test 3: Matrix Multiplication @@ -114,9 +114,9 @@ max_diff = np.max(np.abs(C_result - C_cpu)) rel_error = max_diff / np.max(np.abs(C_cpu)) - print(f" GPU time: {gpu_time*1000:.3f} ms") - print(f" CPU time: {cpu_time*1000:.3f} ms") - print(f" Speedup: {cpu_time/gpu_time:.2f}x") + print(f" GPU time: {gpu_time * 1000:.3f} ms") + print(f" CPU time: {cpu_time * 1000:.3f} ms") + print(f" Speedup: {cpu_time / gpu_time:.2f}x") print(f" Max diff: {max_diff:.2e}") print(f" Rel error: {rel_error:.2e}") @@ -136,7 +136,7 @@ kernel = native.JITKernel(kernel_src, "scale_add") compile_time = time.perf_counter() - start - print(f" Compilation time: {compile_time*1000:.3f} ms") + print(f" Compilation time: {compile_time * 1000:.3f} ms") print(f" Kernel compiled: {kernel.is_compiled}") print(f" PTX length: {len(kernel.ptx)} bytes") @@ -164,9 +164,9 @@ cpu_time = time.perf_counter() - start gflops = 2 * M * N * K / gpu_time / 1e9 - print(f" GPU time: {gpu_time*1000:.3f} ms") - print(f" CPU time: {cpu_time*1000:.3f} ms") - print(f" Speedup: {cpu_time/gpu_time:.2f}x") + print(f" GPU time: {gpu_time * 1000:.3f} ms") + print(f" CPU time: {cpu_time * 1000:.3f} ms") + print(f" Speedup: {cpu_time / gpu_time:.2f}x") print(f" GPU GFLOPS: {gflops:.1f}") print("\n" + "=" * 60) diff --git a/examples/demo_optimized.py b/examples/demo_optimized.py index d40a07c..77babf2 100644 --- a/examples/demo_optimized.py +++ b/examples/demo_optimized.py @@ -70,9 +70,9 @@ max_diff = np.max(np.abs(result_np - expected)) print(f" Elements: {n:,}") - print(f" GPU time: {gpu_time*1000:.3f} ms") - print(f" CPU time: {cpu_time*1000:.3f} ms") - print(f" Speedup: {cpu_time/gpu_time:.2f}x") + print(f" GPU time: {gpu_time * 1000:.3f} ms") + print(f" CPU time: {cpu_time * 1000:.3f} ms") + print(f" Speedup: {cpu_time / gpu_time:.2f}x") print(f" Max diff: {max_diff:.2e}") # Test 2: Matrix multiplication chain @@ -99,9 +99,9 @@ rel_error = np.max(np.abs(result_np - expected)) / np.max(np.abs(expected)) print(f" Size: {M}x{M}") - print(f" GPU time: {gpu_time*1000:.3f} ms") - print(f" CPU time: {cpu_time*1000:.3f} ms") - print(f" Speedup: {cpu_time/gpu_time:.2f}x") + print(f" GPU time: {gpu_time * 1000:.3f} ms") + print(f" CPU time: {cpu_time * 1000:.3f} ms") + print(f" Speedup: {cpu_time / gpu_time:.2f}x") print(f" Rel error: {rel_error:.2e}") # Test 3: Large single operation (where data transfer dominates) @@ -139,12 +139,12 @@ gflops = 2 * M * M * M / gpu_compute / 1e9 transfer_overhead = (gpu_total - gpu_compute) / gpu_total * 100 - print(f" GPU total: {gpu_total*1000:.3f} ms (with H<->D transfer)") - print(f" GPU compute: {gpu_compute*1000:.3f} ms (data on GPU)") - print(f" CPU time: {cpu_time*1000:.3f} ms") + print(f" GPU total: {gpu_total * 1000:.3f} ms (with H<->D transfer)") + print(f" GPU compute: {gpu_compute * 1000:.3f} ms (data on GPU)") + print(f" CPU time: {cpu_time * 1000:.3f} ms") print(f" Transfer overhead: {transfer_overhead:.1f}%") print(f" GPU GFLOPS: {gflops:.1f}") - print(f" Speedup (compute only): {cpu_time/gpu_compute:.2f}x") + print(f" Speedup (compute only): {cpu_time / gpu_compute:.2f}x") # Test 4: Many small operations print("\n4. Many Small Operations: 100x add of 10K elements") @@ -171,9 +171,9 @@ result_cpu = result_cpu + b_np cpu_time = time.perf_counter() - start - print(f" GPU time: {gpu_time*1000:.3f} ms") - print(f" CPU time: {cpu_time*1000:.3f} ms") - print(f" Speedup: {cpu_time/gpu_time:.2f}x") + print(f" GPU time: {gpu_time * 1000:.3f} ms") + print(f" CPU time: {cpu_time * 1000:.3f} ms") + print(f" Speedup: {cpu_time / gpu_time:.2f}x") print("\n" + "=" * 70) print("Summary: Zero-copy operations significantly reduce overhead for") diff --git a/examples/demo_runtime_modes.py b/examples/demo_runtime_modes.py index b8d9d95..5202644 100644 --- a/examples/demo_runtime_modes.py +++ b/examples/demo_runtime_modes.py @@ -10,7 +10,6 @@ """ - def print_header(title: str) -> None: """Print a section header.""" print("\n" + "=" * 60) @@ -49,7 +48,7 @@ def demo_full_jit_mode() -> bool: # Demo: Custom JIT kernel print("\n [Demo] Custom JIT Kernel:") - kernel_source = ''' + kernel_source = """ extern "C" __global__ void scale_array(float* data, float factor, int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -57,7 +56,7 @@ def demo_full_jit_mode() -> bool: data[idx] *= factor; } } - ''' + """ try: kernel = gp.jit(kernel_source, func="scale_array") @@ -176,10 +175,10 @@ def demo_cpu_simulation_mode() -> bool: print(f" - add(128x128): OK (CPU), result shape {C.shape}") # JIT also works in simulation (just marks as compiled) - kernel_source = ''' + kernel_source = """ extern "C" __global__ void dummy(float* x) {} - ''' + """ kernel = gp.jit(kernel_source, func="dummy") print(f" - jit kernel: OK (simulated), compiled={kernel.is_compiled}") diff --git a/examples/demo_v01.py b/examples/demo_v01.py index 4292222..12a4079 100644 --- a/examples/demo_v01.py +++ b/examples/demo_v01.py @@ -119,9 +119,9 @@ elapsed = time.perf_counter() - start print(f"Matrix multiplication {size}x{size}:") -print(f" Time: {elapsed*1000:.2f} ms") +print(f" Time: {elapsed * 1000:.2f} ms") print(f" Result shape: {C.shape}") -print(f" Result sample [0,0]: {result[0,0]:.4f}") +print(f" Result sample [0,0]: {result[0, 0]:.4f}") # Verify correctness A_np = A.to_numpy() diff --git a/examples/demo_v02.py b/examples/demo_v02.py index 1f5ffca..0a9d2da 100644 --- a/examples/demo_v02.py +++ b/examples/demo_v02.py @@ -29,26 +29,32 @@ # Header # ============================================================================= + def print_header(title: str): print("\n" + "=" * 70) print(f" {title}") print("=" * 70) + def print_section(title: str): print(f"\n--- {title} ---") + # ============================================================================= # Main Demo # ============================================================================= + def main(): print_header("PyGPUkit v0.2 Full Feature Demo") # Import modules try: import pygpukit + native = pygpukit._pygpukit_native import _pygpukit_rust as rust + print(f"PyGPUkit version: {pygpukit.__version__}") print("Native module loaded: OK") print("Rust module loaded: OK") @@ -77,7 +83,7 @@ def main(): pool = rust.MemoryPool( quota=100 * 1024 * 1024, # 100 MB quota - enable_eviction=True + enable_eviction=True, ) print("Created pool with 100 MB quota, eviction enabled") @@ -113,7 +119,7 @@ def main(): scheduler = rust.Scheduler( sched_tick_ms=10.0, window_ms=100.0, - total_memory=1024 * 1024 * 1024 # 1 GB + total_memory=1024 * 1024 * 1024, # 1 GB ) print("Created scheduler (10ms tick, 100ms window, 1GB memory)") @@ -124,7 +130,7 @@ def main(): id=f"task_{i}", name=f"Layer {i}", memory_estimate=100 * 1024 * 1024, # 100 MB - priority=i % 3 + priority=i % 3, ) task_id = scheduler.submit(task) task_ids.append(task_id) @@ -170,7 +176,7 @@ def main(): src_ptr=0x1000 + i * 0x1000, dst_ptr=0x2000 + i * 0x1000, size=1024 * 1024, # 1 MB - priority=i % 3 + priority=i % 3, ) transfer_ids.append(op_id) print(f" Queued transfer {op_id}: {type_name.upper()}, priority={i % 3}") @@ -207,13 +213,9 @@ def main(): grid=(128, 1, 1), block=(256, 1, 1), shared_mem=0, - stream_id=i % 2 # Alternate between stream 0 and 1 - ) - req_id = dispatcher.queue( - kernel_handle=0xDEADBEEF + i, - config=config, - priority=i % 3 + stream_id=i % 2, # Alternate between stream 0 and 1 ) + req_id = dispatcher.queue(kernel_handle=0xDEADBEEF + i, config=config, priority=i % 3) print(f" Queued kernel {req_id}: stream={i % 2}, priority={i % 3}") # Get ready kernels @@ -285,23 +287,29 @@ def main(): C_result = C_gpu.to_numpy() rel_error = np.max(np.abs(C_result - C_cpu)) / np.max(np.abs(C_cpu)) - results.append({ - 'size': size, - 'kernel': kernel, - 'time_ms': avg_time * 1000, - 'gflops': gflops, - 'speedup': speedup, - 'error': rel_error - }) + results.append( + { + "size": size, + "kernel": kernel, + "time_ms": avg_time * 1000, + "gflops": gflops, + "speedup": speedup, + "error": rel_error, + } + ) status = "OK" if rel_error < 1e-3 else f"ERR:{rel_error:.1e}" - print(f"{size:>5}x{size:<5} | {kernel:<9} | {avg_time*1000:>8.2f} | {gflops:>7.1f} | {speedup:>5.1f}x ({status})") + print( + f"{size:>5}x{size:<5} | {kernel:<9} | {avg_time * 1000:>8.2f} | {gflops:>7.1f} | {speedup:>5.1f}x ({status})" + ) print("-" * 60) # Peak performance - peak = max(results, key=lambda x: x['gflops']) - print(f"\nPeak: {peak['gflops']:.1f} GFLOPS at {peak['size']}x{peak['size']} ({peak['kernel']})") + peak = max(results, key=lambda x: x["gflops"]) + print( + f"\nPeak: {peak['gflops']:.1f} GFLOPS at {peak['size']}x{peak['size']} ({peak['kernel']})" + ) # ========================================================================= # 6. Integrated Demo - Full Pipeline @@ -335,7 +343,7 @@ def main(): id=f"layer_{layer}", name=f"Layer {layer}", memory_estimate=batch_size * hidden_dim * 4 * 2, # input + output - priority=0 + priority=0, ) task_id = scheduler.submit(task) @@ -346,18 +354,13 @@ def main(): # Queue transfer transfer_engine.enqueue_h2d( - host_ptr=0x1000, - device_ptr=input_block_id, - size=batch_size * hidden_dim * 4 + host_ptr=0x1000, device_ptr=input_block_id, size=batch_size * hidden_dim * 4 ) # Queue kernel config = rust.LaunchConfig.linear(batch_size * hidden_dim, 256) dispatcher.queue( - kernel_handle=0xFFFF0000 + layer, - config=config, - task_id=task_id, - priority=0 + kernel_handle=0xFFFF0000 + layer, config=config, task_id=task_id, priority=0 ) # Actual compute @@ -384,7 +387,7 @@ def main(): throughput = total_flops / total_time / 1e9 print("\nPipeline completed:") - print(f" Total time: {total_time*1000:.2f} ms") + print(f" Total time: {total_time * 1000:.2f} ms") print(f" Throughput: {throughput:.1f} GFLOPS") print(f" Tasks completed: {scheduler.stats().completed_count}") print(f" Memory allocations: {pool.stats().allocation_count}") @@ -434,5 +437,6 @@ def main(): return 0 + if __name__ == "__main__": sys.exit(main()) diff --git a/examples/demo_v023.py b/examples/demo_v023.py index b62b675..352255f 100644 --- a/examples/demo_v023.py +++ b/examples/demo_v023.py @@ -118,14 +118,14 @@ def main(): fp32_rel_err = np.max(np.abs(result_fp32 - expected) / (np.abs(expected) + 1e-8)) print("\nFP32 Error:") print(f" Max absolute error: {fp32_abs_err:.6e}") - print(f" Max relative error: {fp32_rel_err:.6e} ({fp32_rel_err*100:.4f}%)") + print(f" Max relative error: {fp32_rel_err:.6e} ({fp32_rel_err * 100:.4f}%)") # TF32 error (expected to be higher due to reduced precision) tf32_abs_err = np.max(np.abs(result_tf32 - expected)) tf32_rel_err = np.max(np.abs(result_tf32 - expected) / (np.abs(expected) + 1e-8)) print("\nTF32 Error:") print(f" Max absolute error: {tf32_abs_err:.6e}") - print(f" Max relative error: {tf32_rel_err:.6e} ({tf32_rel_err*100:.4f}%)") + print(f" Max relative error: {tf32_rel_err:.6e} ({tf32_rel_err * 100:.4f}%)") # TF32 typically has ~0.1% error per op, accumulating to ~1-5% for large K if tf32_rel_err < 0.1: # 10% threshold @@ -185,17 +185,14 @@ def main(): speedup = tf32_tflops / fp32_tflops if fp32_tflops > 0 else 0 - print(f" FP32: {fp32_tflops:6.2f} TFLOPS ({fp32_time*1000:.2f} ms)") - print(f" TF32: {tf32_tflops:6.2f} TFLOPS ({tf32_time*1000:.2f} ms)") + print(f" FP32: {fp32_tflops:6.2f} TFLOPS ({fp32_time * 1000:.2f} ms)") + print(f" TF32: {tf32_tflops:6.2f} TFLOPS ({tf32_time * 1000:.2f} ms)") print(f" Speedup: {speedup:.2f}x") print() - results.append({ - "size": f"{M}x{N}x{K}", - "fp32": fp32_tflops, - "tf32": tf32_tflops, - "speedup": speedup - }) + results.append( + {"size": f"{M}x{N}x{K}", "fp32": fp32_tflops, "tf32": tf32_tflops, "speedup": speedup} + ) # ========================================================================= # 5. Summary diff --git a/examples/demo_v025.py b/examples/demo_v025.py index 0401b10..edc3478 100644 --- a/examples/demo_v025.py +++ b/examples/demo_v025.py @@ -52,7 +52,7 @@ def benchmark_matmul(a, b, name: str, warmup: int = 3, iterations: int = 10) -> flops = 2.0 * M * N * K tflops = flops / avg_time / 1e12 - print(f" {name}: {avg_time*1000:.2f} ms, {tflops:.2f} TFLOPS") + print(f" {name}: {avg_time * 1000:.2f} ms, {tflops:.2f} TFLOPS") return tflops @@ -104,10 +104,10 @@ def demo_elementwise(): b = gpk.from_numpy(b_np) # Operations - add_result = (a + b) - mul_result = (a * b) - sub_result = (a - b) - div_result = (a / b) + add_result = a + b + mul_result = a * b + sub_result = a - b + div_result = a / b # Convert back for display if gpk_dtype == gpk.bfloat16: @@ -242,8 +242,8 @@ def demo_benchmark_full(): start = time.perf_counter() _ = (a @ b).to_numpy() times.append(time.perf_counter() - start) - flops = 2.0 * size ** 3 - results['FP32'] = flops / np.mean(times) / 1e12 + flops = 2.0 * size**3 + results["FP32"] = flops / np.mean(times) / 1e12 # TF32 for _ in range(3): @@ -254,7 +254,7 @@ def demo_benchmark_full(): start = time.perf_counter() _ = gpk.matmul(a, b, use_tf32=True).to_numpy() times.append(time.perf_counter() - start) - results['TF32'] = flops / np.mean(times) / 1e12 + results["TF32"] = flops / np.mean(times) / 1e12 # FP16 a16 = gpk.from_numpy(np.random.randn(size, size).astype(np.float16)) @@ -268,7 +268,7 @@ def demo_benchmark_full(): start = time.perf_counter() _ = (a16 @ b16).to_numpy() times.append(time.perf_counter() - start) - results['FP16'] = flops / np.mean(times) / 1e12 + results["FP16"] = flops / np.mean(times) / 1e12 # BF16 abf = gpk.from_numpy(np.random.randn(size, size).astype(np.float32)).astype(gpk.bfloat16) @@ -282,9 +282,11 @@ def demo_benchmark_full(): start = time.perf_counter() _ = (abf @ bbf).to_numpy() times.append(time.perf_counter() - start) - results['BF16'] = flops / np.mean(times) / 1e12 + results["BF16"] = flops / np.mean(times) / 1e12 - print(f"{size}x{size:<7} {results['FP32']:<10.2f} {results['TF32']:<10.2f} {results['FP16']:<10.2f} {results['BF16']:<10.2f}") + print( + f"{size}x{size:<7} {results['FP32']:<10.2f} {results['TF32']:<10.2f} {results['FP16']:<10.2f} {results['BF16']:<10.2f}" + ) def main(): diff --git a/examples/demo_v02_full.py b/examples/demo_v02_full.py index d76330b..b003c40 100644 --- a/examples/demo_v02_full.py +++ b/examples/demo_v02_full.py @@ -41,26 +41,32 @@ # Header # ============================================================================= + def print_header(title: str): print("\n" + "=" * 70) print(f" {title}") print("=" * 70) + def print_section(title: str): print(f"\n--- {title} ---") + # ============================================================================= # Main Demo # ============================================================================= + def main(): print_header("PyGPUkit v0.2 Complete Feature Demo") # Import modules try: import pygpukit + native = pygpukit._pygpukit_native import _pygpukit_rust as rust + print(f"PyGPUkit version: {pygpukit.__version__}") print("Native module loaded: OK") print("Rust module loaded: OK") @@ -89,7 +95,7 @@ def main(): pool = rust.MemoryPool( quota=100 * 1024 * 1024, # 100 MB quota - enable_eviction=True + enable_eviction=True, ) print("Created pool with 100 MB quota, eviction enabled") @@ -125,7 +131,7 @@ def main(): scheduler = rust.Scheduler( sched_tick_ms=10.0, window_ms=100.0, - total_memory=1024 * 1024 * 1024 # 1 GB + total_memory=1024 * 1024 * 1024, # 1 GB ) print("Created scheduler (10ms tick, 100ms window, 1GB memory)") @@ -136,7 +142,7 @@ def main(): id=f"task_{i}", name=f"Layer {i}", memory_estimate=100 * 1024 * 1024, # 100 MB - priority=i % 3 + priority=i % 3, ) task_id = scheduler.submit(task) task_ids.append(task_id) @@ -171,7 +177,7 @@ def main(): src_ptr=0x1000 + i * 0x1000, dst_ptr=0x2000 + i * 0x1000, size=1024 * 1024, - priority=i % 3 + priority=i % 3, ) print(f" Queued transfer {op_id}: {type_name.upper()}") @@ -182,7 +188,9 @@ def main(): transfer_engine.complete_transfer(op.id) transfer_stats = transfer_engine.stats() - print(f"Transfer stats: {transfer_stats.completed_count} completed, {transfer_stats.pending_count} pending") + print( + f"Transfer stats: {transfer_stats.completed_count} completed, {transfer_stats.pending_count} pending" + ) # ========================================================================= # 4. Rust Kernel Dispatch Controller Demo @@ -194,16 +202,9 @@ def main(): for i in range(4): config = rust.LaunchConfig( - grid=(128, 1, 1), - block=(256, 1, 1), - shared_mem=0, - stream_id=i % 2 - ) - req_id = dispatcher.queue( - kernel_handle=0xDEADBEEF + i, - config=config, - priority=i % 3 + grid=(128, 1, 1), block=(256, 1, 1), shared_mem=0, stream_id=i % 2 ) + req_id = dispatcher.queue(kernel_handle=0xDEADBEEF + i, config=config, priority=i % 3) print(f" Queued kernel {req_id}: stream={i % 2}") ready_kernels = dispatcher.get_ready(max_requests=4) @@ -212,7 +213,9 @@ def main(): dispatcher.mark_completed(req.id) dispatch_stats = dispatcher.stats() - print(f"Dispatch stats: {dispatch_stats.completed_count} completed, {dispatch_stats.pending_count} pending") + print( + f"Dispatch stats: {dispatch_stats.completed_count} completed, {dispatch_stats.pending_count} pending" + ) # ========================================================================= # 5. Admission Control (NEW) @@ -225,7 +228,7 @@ def main(): admission_scheduler = rust.Scheduler( sched_tick_ms=10.0, window_ms=100.0, - total_memory=500 * 1024 * 1024 # 500 MB limit + total_memory=500 * 1024 * 1024, # 500 MB limit ) # Submit tasks that should fit @@ -234,7 +237,7 @@ def main(): id=f"admit_{i}", name=f"Admissible Task {i}", memory_estimate=100 * 1024 * 1024, # 100 MB each - priority=1 + priority=1, ) task_id = admission_scheduler.submit(task) print(f" Admitted task {i}: {task_id}") @@ -254,7 +257,7 @@ def main(): # Create QoS policy evaluator qos_evaluator = rust.QosPolicyEvaluator( total_memory=1024 * 1024 * 1024, # 1 GB - total_bandwidth=1.0 + total_bandwidth=1.0, ) # Test different QoS classes @@ -270,7 +273,9 @@ def main(): class_name = qos_class_names.get(int(task.qos_class), "Unknown") if eval_result.is_admitted(): qos_evaluator.reserve(eval_result) - print(f" {class_name:12} | {task.name:15} | ADMITTED (priority={task.effective_priority()})") + print( + f" {class_name:12} | {task.name:15} | ADMITTED (priority={task.effective_priority()})" + ) elif eval_result.is_throttled(): print(f" {class_name:12} | {task.name:15} | THROTTLED") else: @@ -289,10 +294,7 @@ def main(): print_header("7. Kernel Pacing Engine (NEW)") pacing_config = rust.PacingConfig( - total_bandwidth=1.0, - window_ms=100.0, - min_interval_ms=0.1, - adaptive=True + total_bandwidth=1.0, window_ms=100.0, min_interval_ms=0.1, adaptive=True ) pacing_engine = rust.KernelPacingEngine(pacing_config) print(f"Created pacing engine: {pacing_config}") @@ -326,21 +328,14 @@ def main(): print_header("8. Micro-Slicing Framework (NEW)") slice_config = rust.SliceConfig( - max_items_per_slice=10000, - max_duration_ms=1.0, - min_slices=2, - max_slices=16, - adaptive=True + max_items_per_slice=10000, max_duration_ms=1.0, min_slices=2, max_slices=16, adaptive=True ) slice_scheduler = rust.SliceScheduler(slice_config) print(f"Created slice scheduler: {slice_config}") # Submit kernels for slicing num_slices_1 = slice_scheduler.submit( - kernel_handle=0xAAAA0001, - total_items=50000, - block=(256, 1, 1), - shared_mem=0 + kernel_handle=0xAAAA0001, total_items=50000, block=(256, 1, 1), shared_mem=0 ) print(f"\nKernel 1: 50000 items -> {num_slices_1} slices") @@ -350,7 +345,7 @@ def main(): total_items=30000, block=(256, 1, 1), shared_mem=0, - priority=100 + priority=100, ) print(f"Kernel 2: 30000 items -> {num_slices_2} slices (priority=100)") @@ -361,7 +356,9 @@ def main(): slice_info = slice_scheduler.get_next_slice() if slice_info is None: break - print(f" Slice {slice_info.slice_id}: kernel=0x{slice_info.kernel_handle:X}, offset={slice_info.offset}, count={slice_info.count}") + print( + f" Slice {slice_info.slice_id}: kernel=0x{slice_info.kernel_handle:X}, offset={slice_info.offset}, count={slice_info.count}" + ) slice_scheduler.complete_slice(0.1) # 0.1ms exec time executed += 1 @@ -379,7 +376,7 @@ def main(): pinned_config = rust.PinnedPoolConfig( max_size=256 * 1024 * 1024, # 256 MB enable_pooling=True, - alignment=256 + alignment=256, ) pinned_manager = rust.PinnedMemoryManager(pinned_config) print(f"Created pinned memory manager: {pinned_config}") @@ -422,7 +419,7 @@ def main(): max_entries=1024, max_ptx_size=256 * 1024 * 1024, # 256 MB enable_eviction=True, - ttl_seconds=0.0 # No TTL + ttl_seconds=0.0, # No TTL ) kernel_cache = rust.KernelCache(cache_config) print(f"Created kernel cache: {cache_config}") @@ -435,7 +432,10 @@ def main(): kernels = [ ("__global__ void add_kernel(float* a, float* b, float* c) { ... }", "add_kernel"), ("__global__ void mul_kernel(float* a, float* b, float* c) { ... }", "mul_kernel"), - ("__global__ void matmul_kernel(float* A, float* B, float* C, int M, int N, int K) { ... }", "matmul_kernel"), + ( + "__global__ void matmul_kernel(float* A, float* B, float* C, int M, int N, int K) { ... }", + "matmul_kernel", + ), ] for source, name in kernels: @@ -474,20 +474,30 @@ def main(): partition_config = rust.PartitionConfig( total_memory=8 * 1024 * 1024 * 1024, # 8 GB allow_overcommit=False, - overcommit_ratio=1.0 + overcommit_ratio=1.0, ) partition_manager = rust.PartitionManager(partition_config) print(f"Created partition manager: {partition_config}") # Create partitions partitions = [ - ("inference", "Inference Workload", rust.PartitionLimits.with_memory(4 * 1024 * 1024 * 1024).compute(0.5).bandwidth(0.4)), - ("training", "Training Workload", rust.PartitionLimits.with_memory(3 * 1024 * 1024 * 1024).compute(0.4).bandwidth(0.5)), + ( + "inference", + "Inference Workload", + rust.PartitionLimits.with_memory(4 * 1024 * 1024 * 1024).compute(0.5).bandwidth(0.4), + ), + ( + "training", + "Training Workload", + rust.PartitionLimits.with_memory(3 * 1024 * 1024 * 1024).compute(0.4).bandwidth(0.5), + ), ] for pid, name, limits in partitions: partition_manager.create_partition(pid, name, limits) - print(f" Created partition '{pid}': memory={limits.memory_quota / 1024**3:.0f}GB, compute={limits.compute_share:.0%}") + print( + f" Created partition '{pid}': memory={limits.memory_quota / 1024**3:.0f}GB, compute={limits.compute_share:.0%}" + ) # Assign tasks partition_manager.assign_task("inference-task-1", "inference") @@ -560,22 +570,28 @@ def main(): C_result = C_gpu.to_numpy() rel_error = np.max(np.abs(C_result - C_cpu)) / np.max(np.abs(C_cpu)) - results.append({ - 'size': size, - 'kernel': kernel, - 'time_ms': avg_time * 1000, - 'gflops': gflops, - 'speedup': speedup, - 'error': rel_error - }) + results.append( + { + "size": size, + "kernel": kernel, + "time_ms": avg_time * 1000, + "gflops": gflops, + "speedup": speedup, + "error": rel_error, + } + ) status = "OK" if rel_error < 1e-3 else f"ERR:{rel_error:.1e}" - print(f"{size:>5}x{size:<5} | {kernel:<9} | {avg_time*1000:>8.2f} | {gflops:>7.1f} | {speedup:>5.1f}x ({status})") + print( + f"{size:>5}x{size:<5} | {kernel:<9} | {avg_time * 1000:>8.2f} | {gflops:>7.1f} | {speedup:>5.1f}x ({status})" + ) print("-" * 60) - peak = max(results, key=lambda x: x['gflops']) - print(f"\nPeak: {peak['gflops']:.1f} GFLOPS at {peak['size']}x{peak['size']} ({peak['kernel']})") + peak = max(results, key=lambda x: x["gflops"]) + print( + f"\nPeak: {peak['gflops']:.1f} GFLOPS at {peak['size']}x{peak['size']} ({peak['kernel']})" + ) # ========================================================================= # Summary @@ -612,5 +628,6 @@ def main(): return 0 + if __name__ == "__main__": sys.exit(main()) diff --git a/examples/scheduler_simulation.py b/examples/scheduler_simulation.py index 669a70e..6b80936 100644 --- a/examples/scheduler_simulation.py +++ b/examples/scheduler_simulation.py @@ -27,7 +27,7 @@ def log(msg: str, level: str = "INFO") -> None: def separator(title: str = "") -> None: """Print separator line.""" if title: - print(f"\n{'='*20} {title} {'='*20}") + print(f"\n{'=' * 20} {title} {'=' * 20}") else: print("-" * 60) @@ -68,7 +68,7 @@ def main() -> None: ┌─────────────────────────────────────────────────────────┐ │ Resource │ Total │ Available │ ├─────────────────────────────────────────────────────────┤ - │ GPU Memory │ {props.total_memory/(1024**3):6.2f} GB │ {props.total_memory/(1024**3):6.2f} GB │ + │ GPU Memory │ {props.total_memory / (1024**3):6.2f} GB │ {props.total_memory / (1024**3):6.2f} GB │ │ Streaming MPs │ {props.multiprocessor_count:6d} SMs │ {props.multiprocessor_count:6d} SMs │ │ Bandwidth │ 100.00 % │ 100.00 % │ │ CUDA Streams │ 32 │ 32 │ @@ -143,12 +143,12 @@ def make_workload(name: str, flops: int, duration_ms: float): def workload(): start = time.time() execution_log.append((name, "START", start)) - log(f"[KERNEL] {name}: Launching kernel (est. {flops/1e9:.1f} GFLOPS)", "EXEC") + log(f"[KERNEL] {name}: Launching kernel (est. {flops / 1e9:.1f} GFLOPS)", "EXEC") # Simulate work time.sleep(duration_ms / 1000.0) end = time.time() execution_log.append((name, "END", end)) - log(f"[KERNEL] {name}: Completed in {(end-start)*1000:.2f} ms", "EXEC") + log(f"[KERNEL] {name}: Completed in {(end - start) * 1000:.2f} ms", "EXEC") return workload @@ -185,9 +185,9 @@ def workload(): │ ID: {task_id:<8} │ │ Name: {name:<20} │ │ Memory Request: {mem_mb:>6} MB │ - │ Bandwidth: {bw*100:>5.1f} % │ + │ Bandwidth: {bw * 100:>5.1f} % │ │ Policy: {policy.upper():<12} │ - │ Est. FLOPs: {flops/1e12:.2f} TFLOPS │ + │ Est. FLOPs: {flops / 1e12:.2f} TFLOPS │ │ State: {task.state.name:<10} │ └──────────────────────────────────────────────────────────┘ """) @@ -208,7 +208,7 @@ def workload(): ┌─────────────────────────────────────────────────────────┐ │ Resource │ Reserved │ Available │ ├─────────────────────────────────────────────────────────┤ - │ GPU Memory │ {global_stats['reserved_memory']/(1024**2):6.0f} MB │ {avail_mem/(1024**2):6.0f} MB ({avail_pct:.1f}%) │ + │ GPU Memory │ {global_stats["reserved_memory"] / (1024**2):6.0f} MB │ {avail_mem / (1024**2):6.0f} MB ({avail_pct:.1f}%) │ │ Bandwidth │ 140.0 % │ -40.0 % (!) │ └─────────────────────────────────────────────────────────┘ @@ -280,12 +280,12 @@ def workload(): ] for name, size in alloc_sizes: - log(f"pool.allocate({size // (1024*1024)} MB) for {name}", "ALLOC") + log(f"pool.allocate({size // (1024 * 1024)} MB) for {name}", "ALLOC") try: block = pool.allocate(size) blocks.append((name, block)) stats = pool.stats() - log(f" Block ID: {block.id}, Size class: {block.size // (1024*1024)} MB", "ALLOC") + log(f" Block ID: {block.id}, Size class: {block.size // (1024 * 1024)} MB", "ALLOC") log( f" Pool used: {stats['used'] // (1024**2)} MB, Cached: {stats['cached'] // (1024**2)} MB", "ALLOC", @@ -318,11 +318,11 @@ def workload(): ┌────────────────────────────────────────┐ │ Metric │ Value │ ├────────────────────────────────────────┤ - │ cudaMalloc calls │ {stats['cudamalloc_count']:>6} │ - │ Reuse count │ {stats['reuse_count']:>6} │ - │ Eviction count │ {stats['eviction_count']:>6} │ - │ Active blocks │ {stats['active_blocks']:>6} │ - │ Free blocks │ {stats['free_blocks']:>6} │ + │ cudaMalloc calls │ {stats["cudamalloc_count"]:>6} │ + │ Reuse count │ {stats["reuse_count"]:>6} │ + │ Eviction count │ {stats["eviction_count"]:>6} │ + │ Active blocks │ {stats["active_blocks"]:>6} │ + │ Free blocks │ {stats["free_blocks"]:>6} │ └────────────────────────────────────────┘ """) @@ -396,10 +396,10 @@ def workload(): ┌────────────────────────────────────────┐ │ Metric │ Value │ ├────────────────────────────────────────┤ - │ Total tasks │ {final_stats['task_count']:>6} │ - │ Completed │ {final_stats['completed_count']:>6} │ + │ Total tasks │ {final_stats["task_count"]:>6} │ + │ Completed │ {final_stats["completed_count"]:>6} │ │ Total time │ {total_time:>6.1f} ms │ - │ Avg task time │ {total_time/len(task_ids):>6.1f} ms │ + │ Avg task time │ {total_time / len(task_ids):>6.1f} ms │ └────────────────────────────────────────┘ Per-Task Statistics: diff --git a/src/pygpukit/core/array.py b/src/pygpukit/core/array.py index a5c69b3..c32f32f 100644 --- a/src/pygpukit/core/array.py +++ b/src/pygpukit/core/array.py @@ -240,26 +240,31 @@ def __del__(self) -> None: def __add__(self, other: GPUArray) -> GPUArray: """Element-wise addition.""" from pygpukit.ops.basic import add + return add(self, other) def __sub__(self, other: GPUArray) -> GPUArray: """Element-wise subtraction.""" from pygpukit.ops.basic import sub + return sub(self, other) def __mul__(self, other: GPUArray) -> GPUArray: """Element-wise multiplication.""" from pygpukit.ops.basic import mul + return mul(self, other) def __truediv__(self, other: GPUArray) -> GPUArray: """Element-wise division.""" from pygpukit.ops.basic import div + return div(self, other) def __matmul__(self, other: GPUArray) -> GPUArray: """Matrix multiplication.""" from pygpukit.ops.basic import matmul + return matmul(self, other) # ======================================================================== @@ -310,5 +315,5 @@ def astype(self, dtype: DataType) -> GPUArray: return result else: target_np_dtype = dtype.to_numpy_dtype() - converted = np_data.astype(target_np_dtype) + converted: np.ndarray = np_data.astype(target_np_dtype) return from_numpy(converted) diff --git a/src/pygpukit/core/device.py b/src/pygpukit/core/device.py index a41e7a0..7e5a2d4 100644 --- a/src/pygpukit/core/device.py +++ b/src/pygpukit/core/device.py @@ -65,6 +65,7 @@ def get_device_info(device_id: int = 0) -> DeviceInfo: @dataclass class FallbackDeviceCapabilities: """Fallback DeviceCapabilities when Rust module is not available.""" + device_id: int name: str sm_version: int @@ -107,6 +108,7 @@ def get_device_capabilities(device_id: int = 0): # Try to use Rust DeviceCapabilities try: from pygpukit._pygpukit_rust import DeviceCapabilities + return DeviceCapabilities(sm_version) except ImportError: pass diff --git a/src/pygpukit/jit/compiler.py b/src/pygpukit/jit/compiler.py index b580caf..aaad95b 100644 --- a/src/pygpukit/jit/compiler.py +++ b/src/pygpukit/jit/compiler.py @@ -393,9 +393,7 @@ def _compile_native(self) -> None: # Retry loop for transient errors for retry in range(self._MAX_RETRIES): try: - self._kernel = native.JITKernel( - self._source, self._name, current_options - ) + self._kernel = native.JITKernel(self._source, self._name, current_options) self._ptx = self._kernel.ptx self._is_compiled = self._kernel.is_compiled @@ -476,8 +474,7 @@ def _prepare_compile_options(self, native: Any) -> list[str]: # Check if user already specified -arch has_arch = any( - opt.startswith("-arch=") or opt.startswith("--gpu-architecture=") - for opt in options + opt.startswith("-arch=") or opt.startswith("--gpu-architecture=") for opt in options ) if not has_arch: @@ -646,12 +643,12 @@ def jit( _warmup_error: Exception | None = None # Warmup test kernel -_WARMUP_KERNEL_SOURCE = ''' +_WARMUP_KERNEL_SOURCE = """ extern "C" __global__ void _pygpukit_warmup_kernel(float* x, int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) x[idx] = x[idx]; } -''' +""" def warmup( @@ -714,9 +711,7 @@ def _do_warmup(callback: Callable[[], None] | None = None) -> bool: try: # Check if NVRTC is available if not is_nvrtc_available(): - _warmup_error = NvrtcError( - "NVRTC not available", NvrtcErrorCode.NotLoaded - ) + _warmup_error = NvrtcError("NVRTC not available", NvrtcErrorCode.NotLoaded) _warmup_done = True return False diff --git a/src/pygpukit/ops/basic.py b/src/pygpukit/ops/basic.py index 356de0b..49272a7 100644 --- a/src/pygpukit/ops/basic.py +++ b/src/pygpukit/ops/basic.py @@ -24,6 +24,7 @@ def _validate_same_dtype(a: GPUArray, b: GPUArray, op_name: str) -> None: def _validate_float_dtype(a: GPUArray, op_name: str) -> None: """Validate that array has float dtype.""" from pygpukit.core.dtypes import bfloat16, float16, float32, float64 + if a.dtype not in (float32, float64, float16, bfloat16): raise ValueError(f"{op_name} requires float dtype, got {a.dtype}") @@ -367,6 +368,7 @@ def matmul(a: GPUArray, b: GPUArray, *, use_tf32: bool | None = None) -> GPUArra # Check TF32 dtype requirement early (before backend dispatch) if use_tf32 is True: from pygpukit.core.dtypes import float32 + if a.dtype != float32: raise RuntimeError("TF32 matmul requires float32 dtype") diff --git a/tests/stress_test.py b/tests/stress_test.py index 411518f..d92b7d2 100644 --- a/tests/stress_test.py +++ b/tests/stress_test.py @@ -3,6 +3,7 @@ Tests Rust backend components under sustained load. Default: 5 minutes runtime. """ + import argparse import random import threading @@ -19,6 +20,7 @@ class StressTestStats: """Thread-safe statistics collector.""" + def __init__(self): self.lock = threading.Lock() self.operations = 0 @@ -176,14 +178,14 @@ def stress_qos_evaluator(stats: StressTestStats, duration_sec: float): if qos_type == "guaranteed": task = rust.QosTaskMeta.guaranteed( - task_id, f"Guaranteed {task_counter}", - random.randint(1024, 50 * 1024 * 1024) + task_id, f"Guaranteed {task_counter}", random.randint(1024, 50 * 1024 * 1024) ) elif qos_type == "burstable": task = rust.QosTaskMeta.burstable( - task_id, f"Burstable {task_counter}", + task_id, + f"Burstable {task_counter}", random.randint(1024, 30 * 1024 * 1024), - random.uniform(1.5, 3.0) + random.uniform(1.5, 3.0), ) else: task = rust.QosTaskMeta.best_effort(task_id, f"BestEffort {task_counter}") @@ -222,9 +224,11 @@ def stress_partition_manager(stats: StressTestStats, duration_sec: float): if action == "create" and len(partitions) < 10: pid = f"partition-{partition_counter}" - limits = rust.PartitionLimits().memory( - random.randint(100 * 1024 * 1024, 500 * 1024 * 1024) - ).compute(random.uniform(0.05, 0.3)) + limits = ( + rust.PartitionLimits() + .memory(random.randint(100 * 1024 * 1024, 500 * 1024 * 1024)) + .compute(random.uniform(0.05, 0.3)) + ) manager.create_partition(pid, f"Partition {partition_counter}", limits) partitions.append(pid) partition_counter += 1 @@ -292,8 +296,10 @@ def run_stress_test(duration_minutes: float = 5.0, workers: int = 4): elapsed = now - start_time current_stats = stats.get_stats() ops_per_sec = current_stats["operations"] / elapsed if elapsed > 0 else 0 - print(f"[{elapsed:.0f}s] Ops: {current_stats['operations']:,} " - f"({ops_per_sec:.0f}/s) | Errors: {current_stats['errors']}") + print( + f"[{elapsed:.0f}s] Ops: {current_stats['operations']:,} " + f"({ops_per_sec:.0f}/s) | Errors: {current_stats['errors']}" + ) last_report = now # Wait for all to complete @@ -321,7 +327,7 @@ def run_stress_test(duration_minutes: float = 5.0, workers: int = 4): print(f" Partitioning: {final_stats['partition_ops']:,}") print("-" * 50) - if final_stats['errors'] > 0: + if final_stats["errors"] > 0: print(f"WARNING: {final_stats['errors']} errors occurred during test") return 1 else: @@ -331,10 +337,10 @@ def run_stress_test(duration_minutes: float = 5.0, workers: int = 4): if __name__ == "__main__": parser = argparse.ArgumentParser(description="PyGPUkit Stress Test") - parser.add_argument("--duration", type=float, default=5.0, - help="Test duration in minutes (default: 5)") - parser.add_argument("--workers", type=int, default=4, - help="Workers per component (default: 4)") + parser.add_argument( + "--duration", type=float, default=5.0, help="Test duration in minutes (default: 5)" + ) + parser.add_argument("--workers", type=int, default=4, help="Workers per component (default: 4)") args = parser.parse_args() exit(run_stress_test(args.duration, args.workers)) diff --git a/tests/test_3090ti_performance.py b/tests/test_3090ti_performance.py index d2e5481..1812592 100644 --- a/tests/test_3090ti_performance.py +++ b/tests/test_3090ti_performance.py @@ -13,6 +13,7 @@ - Target: 35.6 TFLOPS (90% of theoretical) - Minimum: 22 TFLOPS (must beat PyTorch baseline) """ + import os import time @@ -20,9 +21,7 @@ import pytest # Setup CUDA DLL path (if CUDA is installed) -cuda_path = os.environ.get( - "CUDA_PATH", r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" -) +cuda_path = os.environ.get("CUDA_PATH", r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4") cuda_bin = os.path.join(cuda_path, "bin") if os.path.isdir(cuda_bin): if cuda_bin not in os.environ.get("PATH", ""): @@ -256,7 +255,9 @@ def test_compute_bound_efficiency(self, check_3090ti): _, tflops = benchmark_matmul(8192, 8192, 8192, warmup=2, iterations=5) efficiency = tflops / RTX_3090TI_THEORETICAL_TFLOPS status = "PASS" if efficiency >= target_efficiency else "BELOW_TARGET" - print(f"\n8192x8192: {tflops:.2f} TFLOPS, efficiency: {efficiency*100:.1f}% (target: {target_efficiency*100:.0f}%) [{status}]") + print( + f"\n8192x8192: {tflops:.2f} TFLOPS, efficiency: {efficiency * 100:.1f}% (target: {target_efficiency * 100:.0f}%) [{status}]" + ) # Always pass - performance is informational def test_memory_bandwidth_utilization(self, check_3090ti): @@ -271,7 +272,9 @@ def test_memory_bandwidth_utilization(self, check_3090ti): bandwidth_gbps = bytes_transferred / time_sec / 1e9 status = "PASS" if bandwidth_gbps >= target_bw else "BELOW_TARGET" - print(f"\n{m}x{n}x{k} bandwidth: {bandwidth_gbps:.1f} GB/s (target: {target_bw}) [{status}]") + print( + f"\n{m}x{n}x{k} bandwidth: {bandwidth_gbps:.1f} GB/s (target: {target_bw}) [{status}]" + ) # Always pass - performance is informational diff --git a/tests/test_rust_admission_qos.py b/tests/test_rust_admission_qos.py index 89a257a..0947f3d 100644 --- a/tests/test_rust_admission_qos.py +++ b/tests/test_rust_admission_qos.py @@ -2,6 +2,7 @@ TDD Tests for v0.2.1 - Admission Control & QoS Policy Spec Tests written FIRST before implementation fixes. """ + import pytest # Skip all tests if Rust module not available @@ -92,10 +93,10 @@ def test_admission_stats(self): controller = rust.AdmissionController(config) stats = controller.stats() - assert hasattr(stats, 'used_memory') - assert hasattr(stats, 'used_bandwidth') - assert hasattr(stats, 'admitted_count') - assert hasattr(stats, 'rejected_count') + assert hasattr(stats, "used_memory") + assert hasattr(stats, "used_bandwidth") + assert hasattr(stats, "admitted_count") + assert hasattr(stats, "rejected_count") class TestQoSPolicySpec: @@ -103,7 +104,7 @@ class TestQoSPolicySpec: def test_qos_class_enum(self): """QoS classes should be Guaranteed, Burstable, BestEffort.""" - assert hasattr(rust, 'QosClass') + assert hasattr(rust, "QosClass") # Should be able to get class values guaranteed = rust.QosClass.Guaranteed @@ -217,8 +218,9 @@ def test_qos_with_partitioning(self): # Create inference partition pm.create_partition( - "inference", "Inference Partition", - rust.PartitionLimits().memory(4 * 1024 * 1024 * 1024).compute(0.5) + "inference", + "Inference Partition", + rust.PartitionLimits().memory(4 * 1024 * 1024 * 1024).compute(0.5), ) # Create QoS evaluator for the partition diff --git a/tests/test_tf32_api.py b/tests/test_tf32_api.py index 59773e7..41908bb 100644 --- a/tests/test_tf32_api.py +++ b/tests/test_tf32_api.py @@ -120,27 +120,27 @@ class TestDeviceCapabilities: def test_device_capabilities_exists(self): """Test that DeviceCapabilities class is available.""" - assert hasattr(gp, 'DeviceCapabilities') or hasattr(gp, 'get_device_capabilities') + assert hasattr(gp, "DeviceCapabilities") or hasattr(gp, "get_device_capabilities") def test_device_capabilities_tensorcore_field(self): """Test that DeviceCapabilities has tensorcore field.""" # Get capabilities for current device caps = gp.get_device_capabilities() - assert hasattr(caps, 'tensorcore') + assert hasattr(caps, "tensorcore") assert isinstance(caps.tensorcore, bool) def test_device_capabilities_sm_version(self): """Test that DeviceCapabilities has SM version info.""" caps = gp.get_device_capabilities() - assert hasattr(caps, 'sm_version') or hasattr(caps, 'compute_capability') + assert hasattr(caps, "sm_version") or hasattr(caps, "compute_capability") def test_tensorcore_requires_sm80(self): """Test that tensorcore is True only for SM >= 80.""" caps = gp.get_device_capabilities() - sm_version = getattr(caps, 'sm_version', None) or getattr(caps, 'compute_capability', 0) + sm_version = getattr(caps, "sm_version", None) or getattr(caps, "compute_capability", 0) if sm_version >= 80: # Ampere or newer should have tensor cores assert caps.tensorcore is True @@ -157,7 +157,8 @@ def test_kernel_type_exists(self): # This should be exposed via pygpukit._pygpukit_rust try: from pygpukit._pygpukit_rust import KernelType - assert hasattr(KernelType, 'Tf32Mma') or hasattr(KernelType, 'TF32_MMA') + + assert hasattr(KernelType, "Tf32Mma") or hasattr(KernelType, "TF32_MMA") except ImportError: # Rust module may not be built yet - skip pytest.skip("Rust module not available") @@ -166,7 +167,8 @@ def test_kernel_type_fp32_exists(self): """Test that FP32 kernel type exists.""" try: from pygpukit._pygpukit_rust import KernelType - assert hasattr(KernelType, 'Fp32Fma') or hasattr(KernelType, 'FP32_FMA') + + assert hasattr(KernelType, "Fp32Fma") or hasattr(KernelType, "FP32_FMA") except ImportError: pytest.skip("Rust module not available") diff --git a/tests/test_tf32_tensorcore.py b/tests/test_tf32_tensorcore.py index af5fcf9..ac0b7d7 100644 --- a/tests/test_tf32_tensorcore.py +++ b/tests/test_tf32_tensorcore.py @@ -14,6 +14,7 @@ - mma.sync.aligned.m16n8k8.row.col.tf32.tf32.f32 - 256 TFLOPS theoretical (TF32) """ + import os import time @@ -21,9 +22,7 @@ import pytest # Setup CUDA DLL path (if CUDA is installed) -cuda_path = os.environ.get( - "CUDA_PATH", r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" -) +cuda_path = os.environ.get("CUDA_PATH", r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4") cuda_bin = os.path.join(cuda_path, "bin") if os.path.isdir(cuda_bin): if cuda_bin not in os.environ.get("PATH", ""): @@ -75,7 +74,9 @@ def check_tensorcore(): if not has_tensorcore_support(): pytest.skip("TensorCore (SM >= 80) not available") props = native.get_device_properties(0) - print(f"\nGPU: {props.name} (SM {props.compute_capability_major}{props.compute_capability_minor})") + print( + f"\nGPU: {props.name} (SM {props.compute_capability_major}{props.compute_capability_minor})" + ) return props @@ -282,7 +283,9 @@ def test_tf32_faster_than_fp32(self, check_tensorcore): tf32_tflops = compute_tflops(m, n, k, tf32_time) status = "PASS" if tf32_tflops >= target else "BELOW_TARGET" - print(f"\nTF32 4096x4096: {tf32_tflops:.2f} TFLOPS (target: {target}, FP32 baseline: ~18) [{status}]") + print( + f"\nTF32 4096x4096: {tf32_tflops:.2f} TFLOPS (target: {target}, FP32 baseline: ~18) [{status}]" + ) # Always pass - performance is informational