From 9cf0128229497c1c6753bd909b0fe9afe06799d7 Mon Sep 17 00:00:00 2001 From: manoflearning <77jwk0724@gmail.com> Date: Wed, 17 Sep 2025 20:30:26 +0900 Subject: [PATCH 1/5] feat: introduce CUDA backend --- Cargo.lock | 92 +++++++++++ Cargo.toml | 2 + README.md | 12 +- src/backend/cuda.rs | 367 +++++++++++++++++++++++++++++++++++++++++++ src/backend/mod.rs | 9 +- src/py/mod.rs | 73 ++++++++- src/tests/backend.rs | 191 +++++++++++++++++++++- 7 files changed, 741 insertions(+), 5 deletions(-) create mode 100644 src/backend/cuda.rs diff --git a/Cargo.lock b/Cargo.lock index 602f351..381432c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,12 +28,23 @@ checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" name = "cranberry" version = "0.1.3" dependencies = [ + "cudarc", + "once_cell", "pyo3", "rand", "rand_distr", "thiserror", ] +[[package]] +name = "cudarc" +version = "0.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72ba848ae5c6f3cb36e71eab5f268763e3fabcabe3f7bc683e16f7fa3d46281e" +dependencies = [ + "libloading", +] + [[package]] name = "find-msvc-tools" version = "0.1.1" @@ -69,6 +80,16 @@ version = "0.2.175" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" +[[package]] +name = "libloading" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" +dependencies = [ + "cfg-if", + "windows-targets", +] + [[package]] name = "libm" version = "0.2.15" @@ -307,6 +328,77 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + +[[package]] +name = "windows-targets" +version = "0.53.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + 
+[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "zerocopy" version = "0.8.27" diff --git a/Cargo.toml b/Cargo.toml index ba51973..b644614 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,8 @@ pyo3 = { version = "0.24.1", features = ["extension-module"] } rand = { version = "0.8.5" } rand_distr = "0.4" thiserror = "1" +cudarc = { version = "0.17", features = ["driver", "nvrtc", "cuda-version-from-build-system"] } +once_cell = "1.19" [features] default = ["python"] diff --git a/README.md b/README.md index 02eef22..5ef93d8 100644 --- a/README.md +++ b/README.md @@ -15,13 +15,14 @@ A small deep learning framework in Rust and Python ## Overview -Cranberry is an educational project exploring how a tensor library, automatic differentiation, and a Rust-backed storage layer fit together. The Python front-end intentionally stays simple while the Rust extension supplies fast contiguous kernels and view manipulation utilities. Everything currently targets CPU and 32-bit floating point tensors. +Cranberry is an educational project exploring how a tensor library, automatic differentiation, and a Rust-backed storage layer fit together. The Python front-end intentionally stays simple while the Rust extension supplies fast contiguous kernels and view manipulation utilities. Everything targets 32-bit floating point tensors and now offers optional CUDA acceleration for pointwise kernels alongside the CPU path. ## Highlights - Python-first `Tensor` API backed by the `StorageView` PyO3 module. - Reverse-mode autograd with topological traversal, supporting gradient tracking through broadcasting and reshape/expand/permute transforms. - Contiguous CPU kernels for unary/binary ops plus sum/max reductions, with broadcasting handled in Python. 
+- Optional CUDA backend for contiguous unary and binary operations when an NVIDIA GPU and toolkit are available. - Basic neural-network building blocks (`nn.Linear`, `nn.ReLU`, `nn.Sequential`) and stochastic gradient descent in `optim.SGD`. - Visualization helpers for autograd graphs (`cranberry.features.visualize`) and an MNIST downloader with caching (`cranberry.features.datasets`). @@ -56,11 +57,12 @@ Cranberry is an educational project exploring how a tensor library, automatic di **Rust Extension** - `StorageView` exposes contiguous tensor storage, reshaping, expanding, permuting, and random fills. - CPU backend implements SIMD-accelerated unary/binary kernels and reduction routines. +- CUDA backend (via `cudarc`) mirrors the contiguous unary/binary kernels when a CUDA device is detected. - Views currently support up to rank-4 tensors; non-contiguous reshape/permute paths are under construction. ## Limitations & Work in Progress -- CPU-only; GPU/Metal backends are stubbed out. +- CUDA backend currently covers only contiguous unary/binary kernels; Metal remains stubbed out. - Autograd requires scalar losses and does not yet handle slicing/indexing/in-place mutations. - Views must be contiguous for most kernels; slicing and advanced indexing are not implemented. - Only `float32` tensors are supported; dtype promotion and mixed precision are future work. @@ -98,6 +100,12 @@ Optional extras: - `pip install -e .[datasets]` for download progress via `tqdm`. - `pip install -e .[all]` to include every extra. +## CUDA backend + +- Requires an NVIDIA driver and CUDA toolkit (NVRTC must be discoverable via `CUDA_HOME`, `CUDA_PATH`, or the default `/usr/local/cuda`). +- No separate `nvcc` build step is needed—the crate compiles its kernels at runtime using NVRTC. +- At runtime, pass `device="cuda"` when creating tensors/storage; contiguous unary and binary ops will then execute on the GPU, and raise a runtime error (there is no automatic CPU fallback) if no CUDA device is present. 
+ ## Quickstart ```python diff --git a/src/backend/cuda.rs b/src/backend/cuda.rs new file mode 100644 index 0000000..2c6e453 --- /dev/null +++ b/src/backend/cuda.rs @@ -0,0 +1,367 @@ +use std::{fmt, sync::Arc}; + +use cudarc::{ + driver::{self, CudaContext, CudaFunction, CudaSlice, CudaStream, LaunchConfig, PushKernelArg}, + nvrtc, +}; +use once_cell::sync::OnceCell; + +use super::{Backend, BackendError, BackendResult, BinaryOp, UnaryOp}; +use crate::core::{ + storage::StorageInner, + view::{contiguous_strides, View}, +}; +use crate::device::Device; + +/// CUDA kernels expressed as C strings compiled at runtime with NVRTC. +const KERNEL_SOURCE: &str = r#" +// NVRTC treats this source as C++, so we pull in the headers that define the +// math intrinsics (`sqrtf`, `expf`, `logf`) and the `size_t` type that we use for +// element counts. +#include +#include + +// Every kernel follows the exact same thread-indexing pattern: each CUDA thread +// owns a unique flat index into the logical tensor. The math is identical for all +// kernels, so we declare a helper macro to make the generated PTX a little smaller +// and, just as importantly for new CUDA developers, to keep the thread math readable. +#define GLOBAL_INDEX() (blockIdx.x * blockDim.x + threadIdx.x) + +extern "C" __global__ void unary_neg(const float* __restrict__ inp, + float* __restrict__ out, + size_t n) { + // Convert the 3-D CUDA launch geometry into a single linear index. + const size_t idx = GLOBAL_INDEX(); + // Guard against the final, partially-filled block writing past the end. + if (idx < n) { + // Apply the unary operation in-place for this linear element. 
+ out[idx] = -inp[idx]; + } +} + +extern "C" __global__ void unary_sqrt(const float* __restrict__ inp, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = sqrtf(inp[idx]); + } +} + +extern "C" __global__ void unary_exp(const float* __restrict__ inp, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = expf(inp[idx]); + } +} + +extern "C" __global__ void unary_log(const float* __restrict__ inp, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = logf(inp[idx]); + } +} + +extern "C" __global__ void unary_relu(const float* __restrict__ inp, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + const float v = inp[idx]; + out[idx] = v > 0.0f ? v : 0.0f; + } +} + +extern "C" __global__ void binary_add(const float* __restrict__ lhs, + const float* __restrict__ rhs, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = lhs[idx] + rhs[idx]; + } +} + +extern "C" __global__ void binary_sub(const float* __restrict__ lhs, + const float* __restrict__ rhs, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = lhs[idx] - rhs[idx]; + } +} + +extern "C" __global__ void binary_mul(const float* __restrict__ lhs, + const float* __restrict__ rhs, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = lhs[idx] * rhs[idx]; + } +} + +extern "C" __global__ void binary_div(const float* __restrict__ lhs, + const float* __restrict__ rhs, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = lhs[idx] / rhs[idx]; + } +} +"#; + +/// Lazily constructed global CUDA backend. 
+/// +/// We initialize everything on first use so that simply importing the Python +/// extension does not require the user to have a CUDA device installed. +static CUDA_BACKEND: OnceCell> = OnceCell::new(); + +fn map_driver_err(err: driver::result::DriverError) -> BackendError { + BackendError::Cuda(err.to_string()) +} + +struct UnaryKernels { + neg: CudaFunction, + sqrt: CudaFunction, + exp: CudaFunction, + log: CudaFunction, + relu: CudaFunction, +} + +struct BinaryKernels { + add: CudaFunction, + sub: CudaFunction, + mul: CudaFunction, + div: CudaFunction, +} + +/// Holds the CUDA context, primary stream, and all kernel handles shared by GPU ops. +pub struct CudaBackend { + stream: Arc, + unary: UnaryKernels, + binary: BinaryKernels, +} + +impl fmt::Debug for CudaBackend { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CudaBackend") + .field("device", &Device::Cuda) + .finish() + } +} + +impl CudaBackend { + /// Returns the process-wide CUDA backend instance, creating it on first use. + pub fn global() -> BackendResult<&'static CudaBackend> { + CUDA_BACKEND + .get_or_init(|| init_backend()) + .as_ref() + .map_err(Clone::clone) + } +} + +fn init_backend() -> Result { + let device_count = CudaContext::device_count().map_err(map_driver_err)?; + if device_count <= 0 { + return Err(BackendError::CudaUnavailable( + "no CUDA devices detected".to_string(), + )); + } + + // Acquire the primary context for device 0 and keep its default stream alive. + // We intentionally pick device 0 for now; extending this to support multiple + // GPUs just means plumbing an ordinal through the Python API and reusing the + // exact same initialization procedure. + let ctx = CudaContext::new(0).map_err(map_driver_err)?; + let stream = ctx.default_stream(); + + // Compile the small collection of pointwise kernels. 
NVRTC gives us a way to + // embed plain CUDA C++ as a string and turn it into PTX at runtime, which keeps + // the crate self-contained and removes the need for a separate `nvcc` build + // step in the user environment. + let ptx = + nvrtc::compile_ptx(KERNEL_SOURCE).map_err(|err| BackendError::Cuda(err.to_string()))?; + let module = ctx.load_module(ptx).map_err(map_driver_err)?; + + let unary = UnaryKernels { + neg: module.load_function("unary_neg").map_err(map_driver_err)?, + sqrt: module.load_function("unary_sqrt").map_err(map_driver_err)?, + exp: module.load_function("unary_exp").map_err(map_driver_err)?, + log: module.load_function("unary_log").map_err(map_driver_err)?, + relu: module.load_function("unary_relu").map_err(map_driver_err)?, + }; + + let binary = BinaryKernels { + add: module.load_function("binary_add").map_err(map_driver_err)?, + sub: module.load_function("binary_sub").map_err(map_driver_err)?, + mul: module.load_function("binary_mul").map_err(map_driver_err)?, + div: module.load_function("binary_div").map_err(map_driver_err)?, + }; + + Ok(CudaBackend { + stream, + unary, + binary, + }) +} + +impl CudaBackend { + fn ensure_contiguous(view: &View) -> BackendResult<()> { + if view.is_contiguous() { + Ok(()) + } else { + // For now we only expose element-wise kernels, so we require buffers + // to be densely packed in memory. The Python binding already exposes + // `.contiguous()` to users, so we surface the same error that the CPU + // backend reports in the mismatch case. + Err(BackendError::NotContiguous) + } + } + + fn to_device_slice(&self, view: &View) -> BackendResult> { + // The storage currently lives in host memory. Before we can launch a + // kernel we need to stage that data on the device. `memcpy_stod` + // allocates the device buffer for us and schedules the host-to-device + // transfer onto the stream we cached during initialization. 
+ let slice = view.inner.as_slice(view.offset, view.numel()); + self.stream.memcpy_stod(slice).map_err(map_driver_err) + } + + fn from_device_slice(&self, slice: CudaSlice, shape: &[usize]) -> BackendResult { + // The kernel launch is asynchronous with respect to the host. We force + // completion here so that when we copy the results back we are guaranteed + // to see the writes that the GPU just produced. + self.stream.synchronize().map_err(map_driver_err)?; + // Bring the freshly computed buffer back to the host. The helper allocates + // a `Vec` of the right size and performs a device-to-host copy. + let host = self.stream.memcpy_dtov(&slice).map_err(map_driver_err)?; + // Wrap the vector in our existing `StorageInner` abstraction so the rest + // of the codebase can keep treating the buffer like any other contiguous + // allocation. We still tag it as `Device::Cuda` to preserve the logical + // device placement that the caller requested. + let inner = StorageInner::from_vec(host, Device::Cuda); + Ok(View { + inner: Arc::new(inner), + offset: 0, + shape: shape.to_vec(), + strides: contiguous_strides(shape), + }) + } +} + +impl Backend for CudaBackend { + fn unary(&self, op: UnaryOp, a: &View) -> BackendResult { + Self::ensure_contiguous(a)?; + + // Upload the source tensor to device memory so the kernel can consume it. + // We do this lazily on every invocation; a production implementation would + // want to keep allocations around, but this keeps the control flow easy to + // follow while still exercising the GPU cores. + let device_in = self.to_device_slice(a)?; + let mut device_out = self + .stream + .alloc_zeros::(a.numel()) + .map_err(map_driver_err)?; + // `alloc_zeros` both reserves device memory and memset's it to 0. The + // zero-fill is not strictly required for our kernels, but it gives us a + // well-defined value in case the GPU ever reads an out-of-bounds index + // (which would indicate a bug elsewhere). 
+ + // Precompute the launch geometry once per kernel dispatch. The helper + // chooses a 1024-thread block size and enough blocks to touch every + // element, which works well for simple pointwise kernels. + let cfg = LaunchConfig::for_num_elems(a.numel() as u32); + let len = a.numel(); + + // The launch builder consumes immutable or mutable references to CUDA + // slices. We build the argument list in the exact same way for every + // unary op and only vary the function handle. + let mut launch = |func: &CudaFunction| -> BackendResult<()> { + // Launching work on a CUDA stream is inherently unsafe because the + // compiler cannot prove that the raw pointers we pass stay valid for + // the duration of the kernel. Every argument we push here is owned by + // a `CudaSlice`, so the buffers remain alive until the launch has + // completed. + unsafe { + self.stream + .launch_builder(func) + .arg(&device_in) + .arg(&mut device_out) + .arg(&len) + .launch(cfg) + } + .map(|_| ()) + .map_err(map_driver_err) + }; + + match op { + UnaryOp::Neg => launch(&self.unary.neg)?, + UnaryOp::Sqrt => launch(&self.unary.sqrt)?, + UnaryOp::Exp => launch(&self.unary.exp)?, + UnaryOp::Log => launch(&self.unary.log)?, + UnaryOp::Relu => launch(&self.unary.relu)?, + } + + self.from_device_slice(device_out, &a.shape) + } + + fn binary(&self, op: BinaryOp, a: &View, b: &View) -> BackendResult { + Self::ensure_contiguous(a)?; + Self::ensure_contiguous(b)?; + + if a.numel() != b.numel() { + return Err(BackendError::ShapeMismatch); + } + + // Mirror the upload path we used for the unary kernels, but stage both + // operands on the device. + let device_a = self.to_device_slice(a)?; + let device_b = self.to_device_slice(b)?; + let mut device_out = self + .stream + .alloc_zeros::(a.numel()) + .map_err(map_driver_err)?; + // The output buffer is shared across all binary launches to avoid + // allocating multiple temporaries when we chain different kernels for the + // same call. 
+ + // Again, a single helper computes the dispatch geometry. + let cfg = LaunchConfig::for_num_elems(a.numel() as u32); + let len = a.numel(); + + // Build-and-launch pipeline that mirrors the unary helper, this time + // attaching both input buffers. + let mut launch = |func: &CudaFunction| -> BackendResult<()> { + // Same safety story as the unary path: we guarantee that the + // `CudaSlice` lifetimes outlive the kernel execution so the raw + // pointers remain valid for the GPU. + unsafe { + self.stream + .launch_builder(func) + .arg(&device_a) + .arg(&device_b) + .arg(&mut device_out) + .arg(&len) + .launch(cfg) + } + .map(|_| ()) + .map_err(map_driver_err) + }; + + match op { + BinaryOp::Add => launch(&self.binary.add)?, + BinaryOp::Sub => launch(&self.binary.sub)?, + BinaryOp::Mul => launch(&self.binary.mul)?, + BinaryOp::Div => launch(&self.binary.div)?, + } + + self.from_device_slice(device_out, &a.shape) + } +} diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 18ad865..993c5f3 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -19,12 +19,16 @@ pub enum BinaryOp { Div, } -#[derive(Debug, Error)] +#[derive(Debug, Error, Clone)] pub enum BackendError { #[error("operation requires contiguous views")] NotContiguous, #[error("shape mismatch")] ShapeMismatch, + #[error("cuda runtime error: {0}")] + Cuda(String), + #[error("cuda device unavailable: {0}")] + CudaUnavailable(String), } pub type BackendResult = Result; @@ -41,8 +45,11 @@ use std::sync::Arc; use crate::core::storage::StorageInner; use crate::device::Device; +mod cuda; mod kernels_simd; +pub use cuda::CudaBackend; + pub struct CpuBackend; impl Backend for CpuBackend { diff --git a/src/py/mod.rs b/src/py/mod.rs index 8550f94..bf7510b 100644 --- a/src/py/mod.rs +++ b/src/py/mod.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use pyo3::prelude::*; -use crate::backend::{Backend, BinaryOp, CpuBackend, UnaryOp}; +use crate::backend::{Backend, BackendError, BinaryOp, CpuBackend, CudaBackend, 
UnaryOp}; use crate::core::{ reduce::{self, AxisError}, storage::StorageInner, @@ -21,6 +21,14 @@ fn axis_error_to_py(err: AxisError) -> PyErr { } } +fn backend_err_to_py(err: BackendError) -> PyErr { + pyo3::exceptions::PyRuntimeError::new_err(err.to_string()) +} + +fn cuda_backend_or_pyerr() -> Result<&'static CudaBackend, PyErr> { + CudaBackend::global().map_err(backend_err_to_py) +} + #[cfg(test)] mod tests { use super::*; @@ -286,6 +294,13 @@ impl StorageView { .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; Ok(StorageView { view: out }) } + crate::device::Device::Cuda => { + let backend = cuda_backend_or_pyerr()?; + let out = backend + .unary(UnaryOp::Neg, &self.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) + } _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( "unary ops not implemented for this device", )), @@ -300,6 +315,13 @@ impl StorageView { .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; Ok(StorageView { view: out }) } + crate::device::Device::Cuda => { + let backend = cuda_backend_or_pyerr()?; + let out = backend + .unary(UnaryOp::Sqrt, &self.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) + } _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( "unary ops not implemented for this device", )), @@ -314,6 +336,13 @@ impl StorageView { .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; Ok(StorageView { view: out }) } + crate::device::Device::Cuda => { + let backend = cuda_backend_or_pyerr()?; + let out = backend + .unary(UnaryOp::Relu, &self.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) + } _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( "unary ops not implemented for this device", )), @@ -328,6 +357,13 @@ impl StorageView { .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; Ok(StorageView { view: out }) } + crate::device::Device::Cuda => { + let backend = 
cuda_backend_or_pyerr()?; + let out = backend + .unary(UnaryOp::Exp, &self.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) + } _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( "unary ops not implemented for this device", )), @@ -342,6 +378,13 @@ impl StorageView { .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; Ok(StorageView { view: out }) } + crate::device::Device::Cuda => { + let backend = cuda_backend_or_pyerr()?; + let out = backend + .unary(UnaryOp::Log, &self.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) + } _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( "unary ops not implemented for this device", )), @@ -360,6 +403,13 @@ impl StorageView { .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; Ok(StorageView { view: out }) } + crate::device::Device::Cuda => { + let backend = cuda_backend_or_pyerr()?; + let out = backend + .binary(BinaryOp::Add, &self.view, &other.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) + } _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( "binary ops not implemented for this device", )), @@ -377,6 +427,13 @@ impl StorageView { .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; Ok(StorageView { view: out }) } + crate::device::Device::Cuda => { + let backend = cuda_backend_or_pyerr()?; + let out = backend + .binary(BinaryOp::Sub, &self.view, &other.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) + } _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( "binary ops not implemented for this device", )), @@ -394,6 +451,13 @@ impl StorageView { .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; Ok(StorageView { view: out }) } + crate::device::Device::Cuda => { + let backend = cuda_backend_or_pyerr()?; + let out = backend + .binary(BinaryOp::Mul, &self.view, &other.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: 
out }) + } _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( "binary ops not implemented for this device", )), @@ -411,6 +475,13 @@ impl StorageView { .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; Ok(StorageView { view: out }) } + crate::device::Device::Cuda => { + let backend = cuda_backend_or_pyerr()?; + let out = backend + .binary(BinaryOp::Div, &self.view, &other.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) + } _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( "binary ops not implemented for this device", )), diff --git a/src/tests/backend.rs b/src/tests/backend.rs index 7d3200d..9c8e79c 100644 --- a/src/tests/backend.rs +++ b/src/tests/backend.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::backend::{Backend, BinaryOp, CpuBackend, UnaryOp}; +use crate::backend::{Backend, BinaryOp, CpuBackend, CudaBackend, UnaryOp}; use crate::core::{storage::StorageInner, view::View}; use crate::device::Device; @@ -9,6 +9,78 @@ fn vec_view(v: Vec) -> View { View::from_inner_1d(inner) } +fn cuda_view(v: Vec) -> View { + let inner = Arc::new(StorageInner::from_vec(v, Device::Cuda)); + View::from_inner_1d(inner) +} + +fn cuda_backend_or_skip(test_name: &str) -> Option<&'static CudaBackend> { + match CudaBackend::global() { + Ok(be) => Some(be), + Err(err) => { + eprintln!("skipping {test_name}: {err}"); + None + } + } +} + +fn assert_close(got: f32, expected: f32, tol: f32, idx: usize) { + let diff = (got - expected).abs(); + assert!( + diff <= tol, + "value mismatch at index {idx}: got {got}, expected {expected}, diff {diff} > tol {tol}" + ); +} + +fn assert_vec_close(actual: &[f32], expected: &[f32], tol: f32) { + assert_eq!(actual.len(), expected.len(), "length mismatch"); + for (idx, (got, exp)) in actual.iter().zip(expected.iter()).enumerate() { + assert_close(*got, *exp, tol, idx); + } +} + +fn run_cuda_unary_case(name: &str, op: UnaryOp, input: &[f32], reference: F) +where + F: Fn(f32) -> f32, +{ + 
let backend = match cuda_backend_or_skip(name) { + Some(be) => be, + None => return, + }; + + let view = cuda_view(input.to_vec()); + let out = backend + .unary(op, &view) + .unwrap_or_else(|err| panic!("{} failed: {}", name, err)); + let actual = out.inner.as_slice(out.offset, out.numel()); + let expected: Vec = input.iter().copied().map(reference).collect(); + assert_vec_close(actual, expected.as_slice(), 1e-5); +} + +fn run_cuda_binary_case(name: &str, op: BinaryOp, lhs: &[f32], rhs: &[f32], reference: F) +where + F: Fn(f32, f32) -> f32, +{ + let backend = match cuda_backend_or_skip(name) { + Some(be) => be, + None => return, + }; + + let a = cuda_view(lhs.to_vec()); + let b = cuda_view(rhs.to_vec()); + let out = backend + .binary(op, &a, &b) + .unwrap_or_else(|err| panic!("{} failed: {}", name, err)); + let actual = out.inner.as_slice(out.offset, out.numel()); + let expected: Vec = lhs + .iter() + .copied() + .zip(rhs.iter().copied()) + .map(|(x, y)| reference(x, y)) + .collect(); + assert_vec_close(actual, expected.as_slice(), 1e-5); +} + #[test] fn unary_neg_matches_scalar() { let a = vec_view(vec![1.0, -2.0, 3.0, -4.5, 0.25]); @@ -100,3 +172,120 @@ fn binary_not_contiguous_error() { let err = be.binary(BinaryOp::Add, &a_nc, &b).unwrap_err(); assert!(matches!(err, crate::backend::BackendError::NotContiguous)); } + +#[test] +fn cuda_unary_neg_matches_cpu_when_device_available() { + run_cuda_unary_case( + "cuda_unary_neg", + UnaryOp::Neg, + &[1.0, -3.0, 0.5, 7.25], + |x| -x, + ); +} + +#[test] +fn cuda_unary_sqrt_matches_cpu_when_device_available() { + run_cuda_unary_case( + "cuda_unary_sqrt", + UnaryOp::Sqrt, + &[0.0, 0.25, 1.0, 4.0, 9.0], + |x| x.sqrt(), + ); +} + +#[test] +fn cuda_unary_exp_matches_cpu_when_device_available() { + run_cuda_unary_case( + "cuda_unary_exp", + UnaryOp::Exp, + &[1.0, -2.5, 3.25, -4.75, 0.0], + |x| x.exp(), + ); +} + +#[test] +fn cuda_unary_log_matches_cpu_when_device_available() { + run_cuda_unary_case( + "cuda_unary_log", + 
UnaryOp::Log, + &[0.25, 1.0, 2.5, 10.0], + |x| x.ln(), + ); +} + +#[test] +fn cuda_unary_relu_matches_cpu_when_device_available() { + run_cuda_unary_case( + "cuda_unary_relu", + UnaryOp::Relu, + &[-3.0, -0.5, 0.0, 0.5, 2.0, 5.0], + |x| x.max(0.0), + ); +} + +#[test] +fn cuda_binary_add_matches_cpu_when_device_available() { + let lhs: Vec = (0..64).map(|i| i as f32 * 0.5).collect(); + let rhs: Vec = (0..64).map(|i| i as f32 * -0.25).collect(); + run_cuda_binary_case("cuda_binary_add", BinaryOp::Add, &lhs, &rhs, |x, y| x + y); +} + +#[test] +fn cuda_binary_sub_matches_cpu_when_device_available() { + let lhs: Vec = (0..64).map(|i| i as f32 * 0.5).collect(); + let rhs: Vec = (0..64).map(|i| i as f32 * -0.25).collect(); + run_cuda_binary_case("cuda_binary_sub", BinaryOp::Sub, &lhs, &rhs, |x, y| x - y); +} + +#[test] +fn cuda_binary_mul_matches_cpu_when_device_available() { + let lhs: Vec = (0..64).map(|i| i as f32 * 0.5).collect(); + let rhs: Vec = (0..64).map(|i| (i as f32 + 1.0) * 0.1).collect(); + run_cuda_binary_case("cuda_binary_mul", BinaryOp::Mul, &lhs, &rhs, |x, y| x * y); +} + +#[test] +fn cuda_binary_div_matches_cpu_when_device_available() { + let lhs: Vec = (1..65).map(|i| i as f32).collect(); + let rhs: Vec = (1..65).map(|i| (i as f32) * 0.5 + 1.0).collect(); + run_cuda_binary_case("cuda_binary_div", BinaryOp::Div, &lhs, &rhs, |x, y| x / y); +} + +#[test] +fn cuda_unary_not_contiguous_error_when_device_available() { + let backend = match cuda_backend_or_skip("cuda_unary_not_contiguous_error") { + Some(be) => be, + None => return, + }; + let view = cuda_view((0..12).map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); + let non_contig = view.permute(&[1, 0]); + let err = backend.unary(UnaryOp::Neg, &non_contig).unwrap_err(); + assert!(matches!(err, crate::backend::BackendError::NotContiguous)); +} + +#[test] +fn cuda_binary_not_contiguous_error_when_device_available() { + let backend = match cuda_backend_or_skip("cuda_binary_not_contiguous_error") { + 
Some(be) => be, + None => return, + }; + let view_a = cuda_view((0..12).map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); + let view_b = cuda_view((0..12).rev().map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); + let non_contig = view_a.permute(&[1, 0]); + let err = backend + .binary(BinaryOp::Add, &non_contig, &view_b) + .unwrap_err(); + assert!(matches!(err, crate::backend::BackendError::NotContiguous)); +} + +#[test] +fn cuda_binary_shape_mismatch_error_when_device_available() { + let backend = match cuda_backend_or_skip("cuda_binary_shape_mismatch_error") { + Some(be) => be, + None => return, + }; + let a = cuda_view(vec![1.0, 2.0, 3.0]); + let b = cuda_view(vec![4.0, 5.0]); + let err = backend.binary(BinaryOp::Add, &a, &b).unwrap_err(); + assert!(matches!(err, crate::backend::BackendError::ShapeMismatch)); +} From a24d84b06181fdbde2dc38326570c8de94da7bee Mon Sep 17 00:00:00 2001 From: manoflearning <77jwk0724@gmail.com> Date: Wed, 17 Sep 2025 20:32:58 +0900 Subject: [PATCH 2/5] fix: cargo clippy warning --- src/backend/cuda.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/backend/cuda.rs b/src/backend/cuda.rs index 2c6e453..c20c04c 100644 --- a/src/backend/cuda.rs +++ b/src/backend/cuda.rs @@ -161,7 +161,7 @@ impl CudaBackend { /// Returns the process-wide CUDA backend instance, creating it on first use. pub fn global() -> BackendResult<&'static CudaBackend> { CUDA_BACKEND - .get_or_init(|| init_backend()) + .get_or_init(init_backend) .as_ref() .map_err(Clone::clone) } @@ -234,7 +234,11 @@ impl CudaBackend { self.stream.memcpy_stod(slice).map_err(map_driver_err) } - fn from_device_slice(&self, slice: CudaSlice, shape: &[usize]) -> BackendResult { + fn view_from_device_slice( + &self, + slice: CudaSlice, + shape: &[usize], + ) -> BackendResult { // The kernel launch is asynchronous with respect to the host. 
We force // completion here so that when we copy the results back we are guaranteed // to see the writes that the GPU just produced. @@ -309,7 +313,7 @@ impl Backend for CudaBackend { UnaryOp::Relu => launch(&self.unary.relu)?, } - self.from_device_slice(device_out, &a.shape) + self.view_from_device_slice(device_out, &a.shape) } fn binary(&self, op: BinaryOp, a: &View, b: &View) -> BackendResult { @@ -362,6 +366,6 @@ impl Backend for CudaBackend { BinaryOp::Div => launch(&self.binary.div)?, } - self.from_device_slice(device_out, &a.shape) + self.view_from_device_slice(device_out, &a.shape) } } From fd3767b6a62a24de8756f6560e388667778a8209 Mon Sep 17 00:00:00 2001 From: manoflearning <77jwk0724@gmail.com> Date: Thu, 18 Sep 2025 00:20:35 +0900 Subject: [PATCH 3/5] refactor --- src/backend/{kernels_simd.rs => cpu.rs} | 99 ++++-- src/backend/cuda.rs | 388 +++++++++++++----------- src/backend/mod.rs | 92 ++---- src/backend/registry.rs | 30 ++ src/py/mod.rs | 238 ++++----------- src/tests/backend.rs | 41 ++- 6 files changed, 444 insertions(+), 444 deletions(-) rename src/backend/{kernels_simd.rs => cpu.rs} (65%) create mode 100644 src/backend/registry.rs diff --git a/src/backend/kernels_simd.rs b/src/backend/cpu.rs similarity index 65% rename from src/backend/kernels_simd.rs rename to src/backend/cpu.rs index 439cd3d..e301d46 100644 --- a/src/backend/kernels_simd.rs +++ b/src/backend/cpu.rs @@ -1,14 +1,70 @@ use std::simd::num::SimdFloat; use std::simd::{f32x64, StdFloat}; +use super::{ + require_contiguous, require_same_numel, view_from_storage, Backend, BackendResult, BinaryOp, + UnaryOp, +}; +use crate::core::{storage::StorageInner, view::View}; +use crate::device::Device; + const CHUNK_SIZE: usize = 64; -pub mod unary_ops { +pub struct CpuBackend; + +impl CpuBackend { + fn apply_unary(op: UnaryOp, input: &[f32], output: &mut [f32]) { + unary::apply(op, input, output); + } + + fn apply_binary(op: BinaryOp, lhs: &[f32], rhs: &[f32], output: &mut [f32]) { + 
binary::apply(op, lhs, rhs, output); + } +} + +impl Backend for CpuBackend { + fn unary(&self, op: UnaryOp, a: &View) -> BackendResult { + require_contiguous(a)?; + let input = a.inner.as_slice(a.offset, a.numel()); + let mut storage = StorageInner::new_full(0.0, a.numel(), Device::Cpu); + { + let output = storage.as_mut_slice(0, a.numel()); + Self::apply_unary(op, input, output); + } + Ok(view_from_storage(storage, &a.shape)) + } + + fn binary(&self, op: BinaryOp, a: &View, b: &View) -> BackendResult { + require_contiguous(a)?; + require_contiguous(b)?; + require_same_numel(a, b)?; + let lhs = a.inner.as_slice(a.offset, a.numel()); + let rhs = b.inner.as_slice(b.offset, b.numel()); + let mut storage = StorageInner::new_full(0.0, a.numel(), Device::Cpu); + { + let output = storage.as_mut_slice(0, a.numel()); + Self::apply_binary(op, lhs, rhs, output); + } + Ok(view_from_storage(storage, &a.shape)) + } +} + +pub(crate) mod unary { use super::*; use std::ops::Neg; - pub fn neg(a: &[f32], b: &mut [f32]) { - assert!(a.len() == b.len()); + pub fn apply(op: UnaryOp, input: &[f32], output: &mut [f32]) { + assert_eq!(input.len(), output.len()); + match op { + UnaryOp::Neg => neg(input, output), + UnaryOp::Sqrt => sqrt(input, output), + UnaryOp::Exp => exp(input, output), + UnaryOp::Log => log(input, output), + UnaryOp::Relu => relu(input, output), + } + } + + fn neg(a: &[f32], b: &mut [f32]) { let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE); let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE); a_main @@ -21,8 +77,7 @@ pub mod unary_ops { .for_each(|(a, b)| *b = -a); } - pub fn sqrt(a: &[f32], b: &mut [f32]) { - assert!(a.len() == b.len()); + fn sqrt(a: &[f32], b: &mut [f32]) { let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE); let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE); a_main @@ -35,8 +90,7 @@ pub mod unary_ops { .for_each(|(a, b)| *b = a.sqrt()); } - pub fn relu(a: &[f32], b: &mut [f32]) { - 
assert!(a.len() == b.len()); + fn relu(a: &[f32], b: &mut [f32]) { let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE); let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE); let zero = f32x64::splat(0.0); @@ -50,8 +104,7 @@ pub mod unary_ops { .for_each(|(a, b)| *b = a.max(0.0)); } - pub fn exp(a: &[f32], b: &mut [f32]) { - assert!(a.len() == b.len()); + fn exp(a: &[f32], b: &mut [f32]) { let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE); let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE); a_main @@ -64,8 +117,7 @@ pub mod unary_ops { .for_each(|(a, b)| *b = a.exp()); } - pub fn log(a: &[f32], b: &mut [f32]) { - assert!(a.len() == b.len()); + fn log(a: &[f32], b: &mut [f32]) { let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE); let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE); a_main @@ -79,12 +131,22 @@ pub mod unary_ops { } } -pub mod binary_ops { +pub(crate) mod binary { use super::*; use std::ops::{Add, Div, Mul, Sub}; - pub fn add(a: &[f32], b: &[f32], c: &mut [f32]) { - assert!(a.len() == b.len() && b.len() == c.len()); + pub fn apply(op: BinaryOp, lhs: &[f32], rhs: &[f32], output: &mut [f32]) { + assert_eq!(lhs.len(), rhs.len()); + assert_eq!(rhs.len(), output.len()); + match op { + BinaryOp::Add => add(lhs, rhs, output), + BinaryOp::Sub => sub(lhs, rhs, output), + BinaryOp::Mul => mul(lhs, rhs, output), + BinaryOp::Div => div(lhs, rhs, output), + } + } + + fn add(a: &[f32], b: &[f32], c: &mut [f32]) { let main = a.len() - a.len() % CHUNK_SIZE; let (a_main, a_rem) = a.split_at(main); let (b_main, b_rem) = b.split_at(main); @@ -105,8 +167,7 @@ pub mod binary_ops { .for_each(|((a, b), c)| *c = a + b); } - pub fn sub(a: &[f32], b: &[f32], c: &mut [f32]) { - assert!(a.len() == b.len() && b.len() == c.len()); + fn sub(a: &[f32], b: &[f32], c: &mut [f32]) { let main = a.len() - a.len() % CHUNK_SIZE; let (a_main, a_rem) = a.split_at(main); let (b_main, b_rem) = 
b.split_at(main); @@ -127,8 +188,7 @@ pub mod binary_ops { .for_each(|((a, b), c)| *c = a - b); } - pub fn mul(a: &[f32], b: &[f32], c: &mut [f32]) { - assert!(a.len() == b.len() && b.len() == c.len()); + fn mul(a: &[f32], b: &[f32], c: &mut [f32]) { let main = a.len() - a.len() % CHUNK_SIZE; let (a_main, a_rem) = a.split_at(main); let (b_main, b_rem) = b.split_at(main); @@ -149,8 +209,7 @@ pub mod binary_ops { .for_each(|((a, b), c)| *c = a * b); } - pub fn div(a: &[f32], b: &[f32], c: &mut [f32]) { - assert!(a.len() == b.len() && b.len() == c.len()); + fn div(a: &[f32], b: &[f32], c: &mut [f32]) { let main = a.len() - a.len() % CHUNK_SIZE; let (a_main, a_rem) = a.split_at(main); let (b_main, b_rem) = b.split_at(main); diff --git a/src/backend/cuda.rs b/src/backend/cuda.rs index c20c04c..0ae4159 100644 --- a/src/backend/cuda.rs +++ b/src/backend/cuda.rs @@ -1,20 +1,175 @@ use std::{fmt, sync::Arc}; -use cudarc::{ - driver::{self, CudaContext, CudaFunction, CudaSlice, CudaStream, LaunchConfig, PushKernelArg}, - nvrtc, -}; -use once_cell::sync::OnceCell; +use cudarc::driver::{self, CudaContext, CudaSlice, CudaStream, LaunchConfig}; -use super::{Backend, BackendError, BackendResult, BinaryOp, UnaryOp}; -use crate::core::{ - storage::StorageInner, - view::{contiguous_strides, View}, +use super::{ + require_contiguous, require_same_numel, view_from_storage, Backend, BackendError, + BackendResult, BinaryOp, UnaryOp, }; +use crate::core::{storage::StorageInner, view::View}; use crate::device::Device; -/// CUDA kernels expressed as C strings compiled at runtime with NVRTC. 
-const KERNEL_SOURCE: &str = r#" +fn map_driver_err(err: driver::result::DriverError) -> BackendError { + BackendError::Cuda(err.to_string()) +} + +mod kernels { + use super::{map_driver_err, BackendError, BackendResult, BinaryOp, UnaryOp}; + use cudarc::driver::{ + CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream, LaunchConfig, PushKernelArg, + }; + use cudarc::nvrtc; + use std::sync::Arc; + + pub(crate) struct Kernels { + unary: unary::Module, + binary: binary::Module, + } + + impl Kernels { + pub fn compile(ctx: &Arc) -> Result { + // Produce PTX from the inline CUDA source. NVRTC jit-compiles at runtime, + // which keeps the crate self-contained and avoids a separate `nvcc` step. + let ptx = + nvrtc::compile_ptx(SOURCE).map_err(|err| BackendError::Cuda(err.to_string()))?; + let module = ctx.load_module(ptx).map_err(map_driver_err)?; + Ok(Self { + unary: unary::Module::load(&module)?, + binary: binary::Module::load(&module)?, + }) + } + + pub fn launch_unary( + &self, + stream: &Arc, + op: UnaryOp, + input: &CudaSlice, + output: &mut CudaSlice, + len: usize, + cfg: LaunchConfig, + ) -> BackendResult<()> { + self.unary.launch(stream, op, input, output, len, cfg) + } + + pub fn launch_binary( + &self, + stream: &Arc, + op: BinaryOp, + lhs: &CudaSlice, + rhs: &CudaSlice, + output: &mut CudaSlice, + len: usize, + cfg: LaunchConfig, + ) -> BackendResult<()> { + self.binary.launch(stream, op, lhs, rhs, output, len, cfg) + } + } + + mod unary { + use super::*; + + pub struct Module { + neg: CudaFunction, + sqrt: CudaFunction, + exp: CudaFunction, + log: CudaFunction, + relu: CudaFunction, + } + + impl Module { + pub fn load(module: &Arc) -> Result { + Ok(Self { + neg: module.load_function("unary_neg").map_err(map_driver_err)?, + sqrt: module.load_function("unary_sqrt").map_err(map_driver_err)?, + exp: module.load_function("unary_exp").map_err(map_driver_err)?, + log: module.load_function("unary_log").map_err(map_driver_err)?, + relu: 
module.load_function("unary_relu").map_err(map_driver_err)?, + }) + } + + pub fn launch( + &self, + stream: &Arc, + op: UnaryOp, + input: &CudaSlice, + output: &mut CudaSlice, + len: usize, + cfg: LaunchConfig, + ) -> BackendResult<()> { + let func = match op { + UnaryOp::Neg => &self.neg, + UnaryOp::Sqrt => &self.sqrt, + UnaryOp::Exp => &self.exp, + UnaryOp::Log => &self.log, + UnaryOp::Relu => &self.relu, + }; + unsafe { + stream + .launch_builder(func) + .arg(input) + .arg(output) + .arg(&len) + .launch(cfg) + } + .map(|_| ()) + .map_err(map_driver_err) + } + } + } + + mod binary { + use super::*; + + pub struct Module { + add: CudaFunction, + sub: CudaFunction, + mul: CudaFunction, + div: CudaFunction, + } + + impl Module { + pub fn load(module: &Arc) -> Result { + Ok(Self { + add: module.load_function("binary_add").map_err(map_driver_err)?, + sub: module.load_function("binary_sub").map_err(map_driver_err)?, + mul: module.load_function("binary_mul").map_err(map_driver_err)?, + div: module.load_function("binary_div").map_err(map_driver_err)?, + }) + } + + pub fn launch( + &self, + stream: &Arc, + op: BinaryOp, + lhs: &CudaSlice, + rhs: &CudaSlice, + output: &mut CudaSlice, + len: usize, + cfg: LaunchConfig, + ) -> BackendResult<()> { + let func = match op { + BinaryOp::Add => &self.add, + BinaryOp::Sub => &self.sub, + BinaryOp::Mul => &self.mul, + BinaryOp::Div => &self.div, + }; + unsafe { + stream + .launch_builder(func) + .arg(lhs) + .arg(rhs) + .arg(output) + .arg(&len) + .launch(cfg) + } + .map(|_| ()) + .map_err(map_driver_err) + } + } + } + + /// CUDA kernels expressed as C strings compiled at runtime with NVRTC. + const SOURCE: &str = r#" // NVRTC treats this source as C++, so we pull in the headers that define the // math intrinsics (`sqrtf`, `expf`, `logf`) and the `size_t` type that we use for // element counts. 
@@ -30,11 +185,8 @@ const KERNEL_SOURCE: &str = r#" extern "C" __global__ void unary_neg(const float* __restrict__ inp, float* __restrict__ out, size_t n) { - // Convert the 3-D CUDA launch geometry into a single linear index. const size_t idx = GLOBAL_INDEX(); - // Guard against the final, partially-filled block writing past the end. if (idx < n) { - // Apply the unary operation in-place for this linear element. out[idx] = -inp[idx]; } } @@ -116,37 +268,12 @@ extern "C" __global__ void binary_div(const float* __restrict__ lhs, } } "#; - -/// Lazily constructed global CUDA backend. -/// -/// We initialize everything on first use so that simply importing the Python -/// extension does not require the user to have a CUDA device installed. -static CUDA_BACKEND: OnceCell> = OnceCell::new(); - -fn map_driver_err(err: driver::result::DriverError) -> BackendError { - BackendError::Cuda(err.to_string()) -} - -struct UnaryKernels { - neg: CudaFunction, - sqrt: CudaFunction, - exp: CudaFunction, - log: CudaFunction, - relu: CudaFunction, -} - -struct BinaryKernels { - add: CudaFunction, - sub: CudaFunction, - mul: CudaFunction, - div: CudaFunction, } /// Holds the CUDA context, primary stream, and all kernel handles shared by GPU ops. pub struct CudaBackend { stream: Arc, - unary: UnaryKernels, - binary: BinaryKernels, + kernels: kernels::Kernels, } impl fmt::Debug for CudaBackend { @@ -158,74 +285,35 @@ impl fmt::Debug for CudaBackend { } impl CudaBackend { - /// Returns the process-wide CUDA backend instance, creating it on first use. - pub fn global() -> BackendResult<&'static CudaBackend> { - CUDA_BACKEND - .get_or_init(init_backend) - .as_ref() - .map_err(Clone::clone) + /// Creates a new CUDA backend by initializing the context, default stream, and kernels. 
+ pub fn new() -> BackendResult { + Self::init() } -} -fn init_backend() -> Result { - let device_count = CudaContext::device_count().map_err(map_driver_err)?; - if device_count <= 0 { - return Err(BackendError::CudaUnavailable( - "no CUDA devices detected".to_string(), - )); - } + fn init() -> Result { + let device_count = CudaContext::device_count().map_err(map_driver_err)?; + if device_count <= 0 { + return Err(BackendError::CudaUnavailable( + "no CUDA devices detected".to_string(), + )); + } - // Acquire the primary context for device 0 and keep its default stream alive. - // We intentionally pick device 0 for now; extending this to support multiple - // GPUs just means plumbing an ordinal through the Python API and reusing the - // exact same initialization procedure. - let ctx = CudaContext::new(0).map_err(map_driver_err)?; - let stream = ctx.default_stream(); - - // Compile the small collection of pointwise kernels. NVRTC gives us a way to - // embed plain CUDA C++ as a string and turn it into PTX at runtime, which keeps - // the crate self-contained and removes the need for a separate `nvcc` build - // step in the user environment. - let ptx = - nvrtc::compile_ptx(KERNEL_SOURCE).map_err(|err| BackendError::Cuda(err.to_string()))?; - let module = ctx.load_module(ptx).map_err(map_driver_err)?; - - let unary = UnaryKernels { - neg: module.load_function("unary_neg").map_err(map_driver_err)?, - sqrt: module.load_function("unary_sqrt").map_err(map_driver_err)?, - exp: module.load_function("unary_exp").map_err(map_driver_err)?, - log: module.load_function("unary_log").map_err(map_driver_err)?, - relu: module.load_function("unary_relu").map_err(map_driver_err)?, - }; + // Acquire the primary context for device 0 and keep its default stream alive. + // We intentionally pick device 0 for now; extending this to support multiple + // GPUs just means plumbing an ordinal through the Python API and reusing the + // exact same initialization procedure. 
+ let ctx = CudaContext::new(0).map_err(map_driver_err)?; + let stream = ctx.default_stream(); - let binary = BinaryKernels { - add: module.load_function("binary_add").map_err(map_driver_err)?, - sub: module.load_function("binary_sub").map_err(map_driver_err)?, - mul: module.load_function("binary_mul").map_err(map_driver_err)?, - div: module.load_function("binary_div").map_err(map_driver_err)?, - }; + // Load the compiled kernels into the new context. + let kernels = kernels::Kernels::compile(&ctx)?; - Ok(CudaBackend { - stream, - unary, - binary, - }) + Ok(CudaBackend { stream, kernels }) + } } impl CudaBackend { - fn ensure_contiguous(view: &View) -> BackendResult<()> { - if view.is_contiguous() { - Ok(()) - } else { - // For now we only expose element-wise kernels, so we require buffers - // to be densely packed in memory. The Python binding already exposes - // `.contiguous()` to users, so we surface the same error that the CPU - // backend reports in the mismatch case. - Err(BackendError::NotContiguous) - } - } - - fn to_device_slice(&self, view: &View) -> BackendResult> { + fn copy_to_device(&self, view: &View) -> BackendResult> { // The storage currently lives in host memory. Before we can launch a // kernel we need to stage that data on the device. `memcpy_stod` // allocates the device buffer for us and schedules the host-to-device @@ -234,11 +322,7 @@ impl CudaBackend { self.stream.memcpy_stod(slice).map_err(map_driver_err) } - fn view_from_device_slice( - &self, - slice: CudaSlice, - shape: &[usize], - ) -> BackendResult { + fn copy_from_device(&self, slice: CudaSlice, shape: &[usize]) -> BackendResult { // The kernel launch is asynchronous with respect to the host. We force // completion here so that when we copy the results back we are guaranteed // to see the writes that the GPU just produced. @@ -251,24 +335,19 @@ impl CudaBackend { // allocation. We still tag it as `Device::Cuda` to preserve the logical // device placement that the caller requested. 
let inner = StorageInner::from_vec(host, Device::Cuda); - Ok(View { - inner: Arc::new(inner), - offset: 0, - shape: shape.to_vec(), - strides: contiguous_strides(shape), - }) + Ok(view_from_storage(inner, shape)) } } impl Backend for CudaBackend { fn unary(&self, op: UnaryOp, a: &View) -> BackendResult { - Self::ensure_contiguous(a)?; + require_contiguous(a)?; // Upload the source tensor to device memory so the kernel can consume it. // We do this lazily on every invocation; a production implementation would // want to keep allocations around, but this keeps the control flow easy to // follow while still exercising the GPU cores. - let device_in = self.to_device_slice(a)?; + let device_in = self.copy_to_device(a)?; let mut device_out = self .stream .alloc_zeros::(a.numel()) @@ -284,50 +363,21 @@ impl Backend for CudaBackend { let cfg = LaunchConfig::for_num_elems(a.numel() as u32); let len = a.numel(); - // The launch builder consumes immutable or mutable references to CUDA - // slices. We build the argument list in the exact same way for every - // unary op and only vary the function handle. - let mut launch = |func: &CudaFunction| -> BackendResult<()> { - // Launching work on a CUDA stream is inherently unsafe because the - // compiler cannot prove that the raw pointers we pass stay valid for - // the duration of the kernel. Every argument we push here is owned by - // a `CudaSlice`, so the buffers remain alive until the launch has - // completed. 
- unsafe { - self.stream - .launch_builder(func) - .arg(&device_in) - .arg(&mut device_out) - .arg(&len) - .launch(cfg) - } - .map(|_| ()) - .map_err(map_driver_err) - }; - - match op { - UnaryOp::Neg => launch(&self.unary.neg)?, - UnaryOp::Sqrt => launch(&self.unary.sqrt)?, - UnaryOp::Exp => launch(&self.unary.exp)?, - UnaryOp::Log => launch(&self.unary.log)?, - UnaryOp::Relu => launch(&self.unary.relu)?, - } + self.kernels + .launch_unary(&self.stream, op, &device_in, &mut device_out, len, cfg)?; - self.view_from_device_slice(device_out, &a.shape) + self.copy_from_device(device_out, &a.shape) } fn binary(&self, op: BinaryOp, a: &View, b: &View) -> BackendResult { - Self::ensure_contiguous(a)?; - Self::ensure_contiguous(b)?; - - if a.numel() != b.numel() { - return Err(BackendError::ShapeMismatch); - } + require_contiguous(a)?; + require_contiguous(b)?; + require_same_numel(a, b)?; // Mirror the upload path we used for the unary kernels, but stage both // operands on the device. - let device_a = self.to_device_slice(a)?; - let device_b = self.to_device_slice(b)?; + let device_a = self.copy_to_device(a)?; + let device_b = self.copy_to_device(b)?; let mut device_out = self .stream .alloc_zeros::(a.numel()) @@ -340,32 +390,16 @@ impl Backend for CudaBackend { let cfg = LaunchConfig::for_num_elems(a.numel() as u32); let len = a.numel(); - // Build-and-launch pipeline that mirrors the unary helper, this time - // attaching both input buffers. - let mut launch = |func: &CudaFunction| -> BackendResult<()> { - // Same safety story as the unary path: we guarantee that the - // `CudaSlice` lifetimes outlive the kernel execution so the raw - // pointers remain valid for the GPU. 
- unsafe { - self.stream - .launch_builder(func) - .arg(&device_a) - .arg(&device_b) - .arg(&mut device_out) - .arg(&len) - .launch(cfg) - } - .map(|_| ()) - .map_err(map_driver_err) - }; - - match op { - BinaryOp::Add => launch(&self.binary.add)?, - BinaryOp::Sub => launch(&self.binary.sub)?, - BinaryOp::Mul => launch(&self.binary.mul)?, - BinaryOp::Div => launch(&self.binary.div)?, - } - - self.view_from_device_slice(device_out, &a.shape) + self.kernels.launch_binary( + &self.stream, + op, + &device_a, + &device_b, + &mut device_out, + len, + cfg, + )?; + + self.copy_from_device(device_out, &a.shape) } } diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 993c5f3..653381c 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -1,6 +1,10 @@ use thiserror::Error; +use std::sync::Arc; + +use crate::core::storage::StorageInner; use crate::core::view::{contiguous_strides, View}; +use crate::device::Device; #[derive(Debug, Clone, Copy)] pub enum UnaryOp { @@ -25,6 +29,8 @@ pub enum BackendError { NotContiguous, #[error("shape mismatch")] ShapeMismatch, + #[error("backend for device {0:?} is not implemented")] + UnsupportedDevice(Device), #[error("cuda runtime error: {0}")] Cuda(String), #[error("cuda device unavailable: {0}")] @@ -38,69 +44,35 @@ pub trait Backend { fn binary(&self, op: BinaryOp, a: &View, b: &View) -> BackendResult; } -// === CPU backend (contiguous-only for now) === +fn require_contiguous(view: &View) -> BackendResult<()> { + if view.is_contiguous() { + Ok(()) + } else { + Err(BackendError::NotContiguous) + } +} -use std::sync::Arc; +fn require_same_numel(a: &View, b: &View) -> BackendResult<()> { + if a.numel() == b.numel() { + Ok(()) + } else { + Err(BackendError::ShapeMismatch) + } +} -use crate::core::storage::StorageInner; -use crate::device::Device; +fn view_from_storage(inner: StorageInner, shape: &[usize]) -> View { + View { + inner: Arc::new(inner), + offset: 0, + shape: shape.to_vec(), + strides: contiguous_strides(shape), + } 
+} +mod cpu; mod cuda; -mod kernels_simd; - -pub use cuda::CudaBackend; -pub struct CpuBackend; +pub mod registry; -impl Backend for CpuBackend { - fn unary(&self, op: UnaryOp, a: &View) -> BackendResult { - if !a.is_contiguous() { - return Err(BackendError::NotContiguous); - } - let a_slice = a.inner.as_slice(a.offset, a.numel()); - let mut out_inner = StorageInner::new_full(0.0, a.numel(), Device::Cpu); - { - let out_slice = out_inner.as_mut_slice(0, a.numel()); - match op { - UnaryOp::Neg => kernels_simd::unary_ops::neg(a_slice, out_slice), - UnaryOp::Sqrt => kernels_simd::unary_ops::sqrt(a_slice, out_slice), - UnaryOp::Exp => kernels_simd::unary_ops::exp(a_slice, out_slice), - UnaryOp::Log => kernels_simd::unary_ops::log(a_slice, out_slice), - UnaryOp::Relu => kernels_simd::unary_ops::relu(a_slice, out_slice), - } - } - Ok(View { - inner: Arc::new(out_inner), - offset: 0, - shape: a.shape.clone(), - strides: contiguous_strides(&a.shape), - }) - } - - fn binary(&self, op: BinaryOp, a: &View, b: &View) -> BackendResult { - if !a.is_contiguous() || !b.is_contiguous() { - return Err(BackendError::NotContiguous); - } - if a.numel() != b.numel() { - return Err(BackendError::ShapeMismatch); - } - let a_slice = a.inner.as_slice(a.offset, a.numel()); - let b_slice = b.inner.as_slice(b.offset, b.numel()); - let mut out_inner = StorageInner::new_full(0.0, a.numel(), Device::Cpu); - { - let out_slice = out_inner.as_mut_slice(0, a.numel()); - match op { - BinaryOp::Add => kernels_simd::binary_ops::add(a_slice, b_slice, out_slice), - BinaryOp::Sub => kernels_simd::binary_ops::sub(a_slice, b_slice, out_slice), - BinaryOp::Mul => kernels_simd::binary_ops::mul(a_slice, b_slice, out_slice), - BinaryOp::Div => kernels_simd::binary_ops::div(a_slice, b_slice, out_slice), - } - } - Ok(View { - inner: Arc::new(out_inner), - offset: 0, - shape: a.shape.clone(), - strides: contiguous_strides(&a.shape), - }) - } -} +pub use cpu::CpuBackend; +pub use cuda::CudaBackend; diff --git 
a/src/backend/registry.rs b/src/backend/registry.rs new file mode 100644 index 0000000..dc2094c --- /dev/null +++ b/src/backend/registry.rs @@ -0,0 +1,30 @@ +use once_cell::sync::OnceCell; + +use super::{Backend, BackendError, BackendResult, CpuBackend, CudaBackend}; +use crate::device::Device; + +static CPU_BACKEND: CpuBackend = CpuBackend; +static CUDA_BACKEND: OnceCell> = OnceCell::new(); + +fn cuda_backend() -> BackendResult<&'static CudaBackend> { + CUDA_BACKEND + .get_or_init(|| CudaBackend::new()) + .as_ref() + .map_err(Clone::clone) +} + +pub fn cpu() -> &'static CpuBackend { + &CPU_BACKEND +} + +pub fn cuda() -> BackendResult<&'static CudaBackend> { + cuda_backend() +} + +pub fn get(device: Device) -> BackendResult<&'static dyn Backend> { + match device { + Device::Cpu => Ok(cpu() as &dyn Backend), + Device::Cuda => cuda().map(|backend| backend as &dyn Backend), + other => Err(BackendError::UnsupportedDevice(other)), + } +} diff --git a/src/py/mod.rs b/src/py/mod.rs index bf7510b..a936939 100644 --- a/src/py/mod.rs +++ b/src/py/mod.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use pyo3::prelude::*; -use crate::backend::{Backend, BackendError, BinaryOp, CpuBackend, CudaBackend, UnaryOp}; +use crate::backend::{registry, Backend, BackendError, BinaryOp, UnaryOp}; use crate::core::{ reduce::{self, AxisError}, storage::StorageInner, @@ -25,8 +25,16 @@ fn backend_err_to_py(err: BackendError) -> PyErr { pyo3::exceptions::PyRuntimeError::new_err(err.to_string()) } -fn cuda_backend_or_pyerr() -> Result<&'static CudaBackend, PyErr> { - CudaBackend::global().map_err(backend_err_to_py) +fn backend_for_device(device: Device) -> Result<&'static dyn Backend, PyErr> { + match registry::get(device) { + Ok(backend) => Ok(backend), + Err(BackendError::UnsupportedDevice(device)) => { + Err(pyo3::exceptions::PyNotImplementedError::new_err( + BackendError::UnsupportedDevice(device).to_string(), + )) + } + Err(err) => Err(backend_err_to_py(err)), + } } #[cfg(test)] @@ -286,206 +294,80 
@@ impl StorageView { } pub fn neg(&self) -> PyResult { - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .unary(UnaryOp::Neg, &self.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - crate::device::Device::Cuda => { - let backend = cuda_backend_or_pyerr()?; - let out = backend - .unary(UnaryOp::Neg, &self.view) - .map_err(backend_err_to_py)?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "unary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .unary(UnaryOp::Neg, &self.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } pub fn sqrt(&self) -> PyResult { - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .unary(UnaryOp::Sqrt, &self.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - crate::device::Device::Cuda => { - let backend = cuda_backend_or_pyerr()?; - let out = backend - .unary(UnaryOp::Sqrt, &self.view) - .map_err(backend_err_to_py)?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "unary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .unary(UnaryOp::Sqrt, &self.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } pub fn relu(&self) -> PyResult { - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .unary(UnaryOp::Relu, &self.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - crate::device::Device::Cuda => { - let backend = cuda_backend_or_pyerr()?; - 
let out = backend - .unary(UnaryOp::Relu, &self.view) - .map_err(backend_err_to_py)?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "unary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .unary(UnaryOp::Relu, &self.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } pub fn exp(&self) -> PyResult { - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .unary(UnaryOp::Exp, &self.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - crate::device::Device::Cuda => { - let backend = cuda_backend_or_pyerr()?; - let out = backend - .unary(UnaryOp::Exp, &self.view) - .map_err(backend_err_to_py)?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "unary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .unary(UnaryOp::Exp, &self.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } pub fn log(&self) -> PyResult { - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .unary(UnaryOp::Log, &self.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - crate::device::Device::Cuda => { - let backend = cuda_backend_or_pyerr()?; - let out = backend - .unary(UnaryOp::Log, &self.view) - .map_err(backend_err_to_py)?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "unary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .unary(UnaryOp::Log, &self.view) + .map_err(backend_err_to_py)?; + 
Ok(StorageView { view: out }) } pub fn add(&self, other: &StorageView) -> PyResult { if self.view.inner.device() != other.view.inner.device() { return Err(pyo3::exceptions::PyValueError::new_err("device mismatch")); } - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .binary(BinaryOp::Add, &self.view, &other.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - crate::device::Device::Cuda => { - let backend = cuda_backend_or_pyerr()?; - let out = backend - .binary(BinaryOp::Add, &self.view, &other.view) - .map_err(backend_err_to_py)?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "binary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .binary(BinaryOp::Add, &self.view, &other.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } pub fn sub(&self, other: &StorageView) -> PyResult { if self.view.inner.device() != other.view.inner.device() { return Err(pyo3::exceptions::PyValueError::new_err("device mismatch")); } - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .binary(BinaryOp::Sub, &self.view, &other.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - crate::device::Device::Cuda => { - let backend = cuda_backend_or_pyerr()?; - let out = backend - .binary(BinaryOp::Sub, &self.view, &other.view) - .map_err(backend_err_to_py)?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "binary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .binary(BinaryOp::Sub, &self.view, &other.view) + .map_err(backend_err_to_py)?; + 
Ok(StorageView { view: out }) } pub fn mul(&self, other: &StorageView) -> PyResult { if self.view.inner.device() != other.view.inner.device() { return Err(pyo3::exceptions::PyValueError::new_err("device mismatch")); } - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .binary(BinaryOp::Mul, &self.view, &other.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - crate::device::Device::Cuda => { - let backend = cuda_backend_or_pyerr()?; - let out = backend - .binary(BinaryOp::Mul, &self.view, &other.view) - .map_err(backend_err_to_py)?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "binary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .binary(BinaryOp::Mul, &self.view, &other.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } pub fn div(&self, other: &StorageView) -> PyResult { if self.view.inner.device() != other.view.inner.device() { return Err(pyo3::exceptions::PyValueError::new_err("device mismatch")); } - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .binary(BinaryOp::Div, &self.view, &other.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - crate::device::Device::Cuda => { - let backend = cuda_backend_or_pyerr()?; - let out = backend - .binary(BinaryOp::Div, &self.view, &other.view) - .map_err(backend_err_to_py)?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "binary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .binary(BinaryOp::Div, &self.view, &other.view) + .map_err(backend_err_to_py)?; + 
Ok(StorageView { view: out }) } #[pyo3(signature = (dim=None, keepdim=false))] @@ -497,8 +379,8 @@ impl StorageView { let reduced = reduce::reduce_sum(&self.view, axis, keepdim); Ok(StorageView { view: reduced }) } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "sum not implemented for this device", + other => Err(pyo3::exceptions::PyNotImplementedError::new_err( + BackendError::UnsupportedDevice(other).to_string(), )), } } @@ -512,8 +394,8 @@ impl StorageView { let reduced = reduce::reduce_max(&self.view, axis, keepdim); Ok(StorageView { view: reduced }) } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "max not implemented for this device", + other => Err(pyo3::exceptions::PyNotImplementedError::new_err( + BackendError::UnsupportedDevice(other).to_string(), )), } } diff --git a/src/tests/backend.rs b/src/tests/backend.rs index 9c8e79c..767cd62 100644 --- a/src/tests/backend.rs +++ b/src/tests/backend.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::backend::{Backend, BinaryOp, CpuBackend, CudaBackend, UnaryOp}; +use crate::backend::{registry, Backend, BackendError, BinaryOp, CudaBackend, UnaryOp}; use crate::core::{storage::StorageInner, view::View}; use crate::device::Device; @@ -15,7 +15,7 @@ fn cuda_view(v: Vec) -> View { } fn cuda_backend_or_skip(test_name: &str) -> Option<&'static CudaBackend> { - match CudaBackend::global() { + match registry::cuda() { Ok(be) => Some(be), Err(err) => { eprintln!("skipping {test_name}: {err}"); @@ -84,7 +84,7 @@ where #[test] fn unary_neg_matches_scalar() { let a = vec_view(vec![1.0, -2.0, 3.0, -4.5, 0.25]); - let be = CpuBackend; + let be = registry::cpu(); let out = be.unary(UnaryOp::Neg, &a).unwrap(); let actual = out.inner.as_slice(out.offset, out.numel()); let expect: Vec = a @@ -99,7 +99,7 @@ fn unary_neg_matches_scalar() { #[test] fn unary_sqrt_matches_scalar() { let a = vec_view(vec![0.0, 0.25, 1.0, 4.0, 9.0, 16.0, 2.25]); - let be = CpuBackend; + let be = registry::cpu(); let out = 
be.unary(UnaryOp::Sqrt, &a).unwrap(); let actual = out.inner.as_slice(out.offset, out.numel()); let expect: Vec = a @@ -116,7 +116,7 @@ fn unary_sqrt_matches_scalar() { #[test] fn unary_relu_matches_scalar() { let a = vec_view(vec![-3.0, -0.5, 0.0, 0.5, 2.0, 5.0]); - let be = CpuBackend; + let be = registry::cpu(); let out = be.unary(UnaryOp::Relu, &a).unwrap(); let actual = out.inner.as_slice(out.offset, out.numel()); let expect: Vec = a @@ -134,7 +134,7 @@ fn unary_relu_matches_scalar() { fn binary_add_matches_scalar() { let a = vec_view((0..130).map(|i| i as f32 * 0.5).collect()); // span SIMD + remainder let b = vec_view((0..130).map(|i| i as f32 * -0.25).collect()); - let be = CpuBackend; + let be = registry::cpu(); let out = be.binary(BinaryOp::Add, &a, &b).unwrap(); let actual = out.inner.as_slice(out.offset, out.numel()); let a_s = a.inner.as_slice(a.offset, a.numel()); @@ -148,7 +148,7 @@ fn binary_add_matches_scalar() { fn binary_shape_mismatch_error() { let a = vec_view(vec![1.0, 2.0, 3.0]); let b = vec_view(vec![4.0, 5.0]); - let be = CpuBackend; + let be = registry::cpu(); let err = be.binary(BinaryOp::Add, &a, &b).unwrap_err(); assert!(matches!(err, crate::backend::BackendError::ShapeMismatch)); } @@ -158,7 +158,7 @@ fn unary_not_contiguous_error() { let a = vec_view((0..12).map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); let a_nc = a.permute(&[1, 0]); assert!(!a_nc.is_contiguous()); - let be = CpuBackend; + let be = registry::cpu(); let err = be.unary(UnaryOp::Neg, &a_nc).unwrap_err(); assert!(matches!(err, crate::backend::BackendError::NotContiguous)); } @@ -168,11 +168,34 @@ fn binary_not_contiguous_error() { let a = vec_view((0..12).map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); let b = vec_view((0..12).rev().map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); let a_nc = a.permute(&[1, 0]); - let be = CpuBackend; + let be = registry::cpu(); let err = be.binary(BinaryOp::Add, &a_nc, &b).unwrap_err(); assert!(matches!(err, 
crate::backend::BackendError::NotContiguous)); } +#[test] +fn registry_get_cpu_backend() { + let backend = registry::get(Device::Cpu).unwrap(); + let view = vec_view(vec![1.0, -2.0, 3.0]); + let out = backend.unary(UnaryOp::Neg, &view).unwrap(); + assert_eq!( + out.inner.as_slice(out.offset, out.numel()), + &[-1.0, 2.0, -3.0] + ); +} + +#[test] +fn registry_reports_unsupported_device() { + let err = match registry::get(Device::Metal) { + Ok(_) => panic!("expected registry lookup to fail for Metal"), + Err(err) => err, + }; + assert!(matches!( + err, + BackendError::UnsupportedDevice(Device::Metal) + )); +} + #[test] fn cuda_unary_neg_matches_cpu_when_device_available() { run_cuda_unary_case( From 42a7a5b270e470eefa3276c80e892bde07da09ad Mon Sep 17 00:00:00 2001 From: manoflearning <77jwk0724@gmail.com> Date: Thu, 18 Sep 2025 00:28:13 +0900 Subject: [PATCH 4/5] make cudarc optional --- Cargo.toml | 3 ++- src/backend/mod.rs | 3 +++ src/backend/registry.rs | 24 ++++++++++++++++++++---- src/tests/backend.rs | 35 ++++++++++++++++++++++++++++++++++- 4 files changed, 59 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b644614..47e1ad1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,12 +12,13 @@ pyo3 = { version = "0.24.1", features = ["extension-module"] } rand = { version = "0.8.5" } rand_distr = "0.4" thiserror = "1" -cudarc = { version = "0.17", features = ["driver", "nvrtc", "cuda-version-from-build-system"] } +cudarc = { version = "0.17", optional = true, features = ["driver", "nvrtc", "cuda-version-from-build-system"] } once_cell = "1.19" [features] default = ["python"] python = [] +cuda = ["cudarc"] abi3 = ["pyo3/abi3-py37", "generate-import-lib"] generate-import-lib = ["pyo3/generate-import-lib"] diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 653381c..b85fa2e 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -31,6 +31,7 @@ pub enum BackendError { ShapeMismatch, #[error("backend for device {0:?} is not implemented")] 
UnsupportedDevice(Device), + #[cfg(feature = "cuda")] #[error("cuda runtime error: {0}")] Cuda(String), #[error("cuda device unavailable: {0}")] @@ -70,9 +71,11 @@ fn view_from_storage(inner: StorageInner, shape: &[usize]) -> View { } mod cpu; +#[cfg(feature = "cuda")] mod cuda; pub mod registry; pub use cpu::CpuBackend; +#[cfg(feature = "cuda")] pub use cuda::CudaBackend; diff --git a/src/backend/registry.rs b/src/backend/registry.rs index dc2094c..7e5f9b9 100644 --- a/src/backend/registry.rs +++ b/src/backend/registry.rs @@ -1,11 +1,15 @@ -use once_cell::sync::OnceCell; - -use super::{Backend, BackendError, BackendResult, CpuBackend, CudaBackend}; +#[cfg(feature = "cuda")] +use super::CudaBackend; +use super::{Backend, BackendError, BackendResult, CpuBackend}; use crate::device::Device; +#[cfg(feature = "cuda")] +use once_cell::sync::OnceCell; static CPU_BACKEND: CpuBackend = CpuBackend; +#[cfg(feature = "cuda")] static CUDA_BACKEND: OnceCell> = OnceCell::new(); +#[cfg(feature = "cuda")] fn cuda_backend() -> BackendResult<&'static CudaBackend> { CUDA_BACKEND .get_or_init(|| CudaBackend::new()) @@ -17,6 +21,7 @@ pub fn cpu() -> &'static CpuBackend { &CPU_BACKEND } +#[cfg(feature = "cuda")] pub fn cuda() -> BackendResult<&'static CudaBackend> { cuda_backend() } @@ -24,7 +29,18 @@ pub fn cuda() -> BackendResult<&'static CudaBackend> { pub fn get(device: Device) -> BackendResult<&'static dyn Backend> { match device { Device::Cpu => Ok(cpu() as &dyn Backend), - Device::Cuda => cuda().map(|backend| backend as &dyn Backend), + Device::Cuda => { + #[cfg(feature = "cuda")] + { + cuda().map(|backend| backend as &dyn Backend) + } + #[cfg(not(feature = "cuda"))] + { + Err(BackendError::CudaUnavailable( + "crate compiled without the `cuda` feature".to_string(), + )) + } + } other => Err(BackendError::UnsupportedDevice(other)), } } diff --git a/src/tests/backend.rs b/src/tests/backend.rs index 767cd62..56237b5 100644 --- a/src/tests/backend.rs +++ b/src/tests/backend.rs @@ 
-1,6 +1,8 @@ use std::sync::Arc; -use crate::backend::{registry, Backend, BackendError, BinaryOp, CudaBackend, UnaryOp}; +#[cfg(feature = "cuda")] +use crate::backend::CudaBackend; +use crate::backend::{registry, Backend, BackendError, BinaryOp, UnaryOp}; use crate::core::{storage::StorageInner, view::View}; use crate::device::Device; @@ -9,11 +11,13 @@ fn vec_view(v: Vec) -> View { View::from_inner_1d(inner) } +#[cfg(feature = "cuda")] fn cuda_view(v: Vec) -> View { let inner = Arc::new(StorageInner::from_vec(v, Device::Cuda)); View::from_inner_1d(inner) } +#[cfg(feature = "cuda")] fn cuda_backend_or_skip(test_name: &str) -> Option<&'static CudaBackend> { match registry::cuda() { Ok(be) => Some(be), @@ -24,6 +28,7 @@ fn cuda_backend_or_skip(test_name: &str) -> Option<&'static CudaBackend> { } } +#[cfg(feature = "cuda")] fn assert_close(got: f32, expected: f32, tol: f32, idx: usize) { let diff = (got - expected).abs(); assert!( @@ -32,6 +37,7 @@ fn assert_close(got: f32, expected: f32, tol: f32, idx: usize) { ); } +#[cfg(feature = "cuda")] fn assert_vec_close(actual: &[f32], expected: &[f32], tol: f32) { assert_eq!(actual.len(), expected.len(), "length mismatch"); for (idx, (got, exp)) in actual.iter().zip(expected.iter()).enumerate() { @@ -39,6 +45,7 @@ fn assert_vec_close(actual: &[f32], expected: &[f32], tol: f32) { } } +#[cfg(feature = "cuda")] fn run_cuda_unary_case(name: &str, op: UnaryOp, input: &[f32], reference: F) where F: Fn(f32) -> f32, @@ -57,6 +64,7 @@ where assert_vec_close(actual, expected.as_slice(), 1e-5); } +#[cfg(feature = "cuda")] fn run_cuda_binary_case(name: &str, op: BinaryOp, lhs: &[f32], rhs: &[f32], reference: F) where F: Fn(f32, f32) -> f32, @@ -196,6 +204,20 @@ fn registry_reports_unsupported_device() { )); } +#[cfg(not(feature = "cuda"))] +#[test] +fn registry_reports_cuda_unavailable_when_feature_disabled() { + let err = match registry::get(Device::Cuda) { + Ok(_) => panic!("expected CUDA backend lookup to fail without feature"), + 
Err(err) => err, + }; + assert!(matches!( + err, + BackendError::CudaUnavailable(message) if message.contains("`cuda` feature") + )); +} + +#[cfg(feature = "cuda")] #[test] fn cuda_unary_neg_matches_cpu_when_device_available() { run_cuda_unary_case( @@ -206,6 +228,7 @@ fn cuda_unary_neg_matches_cpu_when_device_available() { ); } +#[cfg(feature = "cuda")] #[test] fn cuda_unary_sqrt_matches_cpu_when_device_available() { run_cuda_unary_case( @@ -216,6 +239,7 @@ fn cuda_unary_sqrt_matches_cpu_when_device_available() { ); } +#[cfg(feature = "cuda")] #[test] fn cuda_unary_exp_matches_cpu_when_device_available() { run_cuda_unary_case( @@ -226,6 +250,7 @@ fn cuda_unary_exp_matches_cpu_when_device_available() { ); } +#[cfg(feature = "cuda")] #[test] fn cuda_unary_log_matches_cpu_when_device_available() { run_cuda_unary_case( @@ -236,6 +261,7 @@ fn cuda_unary_log_matches_cpu_when_device_available() { ); } +#[cfg(feature = "cuda")] #[test] fn cuda_unary_relu_matches_cpu_when_device_available() { run_cuda_unary_case( @@ -246,6 +272,7 @@ fn cuda_unary_relu_matches_cpu_when_device_available() { ); } +#[cfg(feature = "cuda")] #[test] fn cuda_binary_add_matches_cpu_when_device_available() { let lhs: Vec = (0..64).map(|i| i as f32 * 0.5).collect(); @@ -253,6 +280,7 @@ fn cuda_binary_add_matches_cpu_when_device_available() { run_cuda_binary_case("cuda_binary_add", BinaryOp::Add, &lhs, &rhs, |x, y| x + y); } +#[cfg(feature = "cuda")] #[test] fn cuda_binary_sub_matches_cpu_when_device_available() { let lhs: Vec = (0..64).map(|i| i as f32 * 0.5).collect(); @@ -260,6 +288,7 @@ fn cuda_binary_sub_matches_cpu_when_device_available() { run_cuda_binary_case("cuda_binary_sub", BinaryOp::Sub, &lhs, &rhs, |x, y| x - y); } +#[cfg(feature = "cuda")] #[test] fn cuda_binary_mul_matches_cpu_when_device_available() { let lhs: Vec = (0..64).map(|i| i as f32 * 0.5).collect(); @@ -267,6 +296,7 @@ fn cuda_binary_mul_matches_cpu_when_device_available() { run_cuda_binary_case("cuda_binary_mul", 
BinaryOp::Mul, &lhs, &rhs, |x, y| x * y); } +#[cfg(feature = "cuda")] #[test] fn cuda_binary_div_matches_cpu_when_device_available() { let lhs: Vec = (1..65).map(|i| i as f32).collect(); @@ -274,6 +304,7 @@ fn cuda_binary_div_matches_cpu_when_device_available() { run_cuda_binary_case("cuda_binary_div", BinaryOp::Div, &lhs, &rhs, |x, y| x / y); } +#[cfg(feature = "cuda")] #[test] fn cuda_unary_not_contiguous_error_when_device_available() { let backend = match cuda_backend_or_skip("cuda_unary_not_contiguous_error") { @@ -286,6 +317,7 @@ fn cuda_unary_not_contiguous_error_when_device_available() { assert!(matches!(err, crate::backend::BackendError::NotContiguous)); } +#[cfg(feature = "cuda")] #[test] fn cuda_binary_not_contiguous_error_when_device_available() { let backend = match cuda_backend_or_skip("cuda_binary_not_contiguous_error") { @@ -301,6 +333,7 @@ fn cuda_binary_not_contiguous_error_when_device_available() { assert!(matches!(err, crate::backend::BackendError::NotContiguous)); } +#[cfg(feature = "cuda")] #[test] fn cuda_binary_shape_mismatch_error_when_device_available() { let backend = match cuda_backend_or_skip("cuda_binary_shape_mismatch_error") { From 2870a5e78f2c57c3f4e0b07259cf32efacd88789 Mon Sep 17 00:00:00 2001 From: manoflearning <77jwk0724@gmail.com> Date: Thu, 18 Sep 2025 00:29:35 +0900 Subject: [PATCH 5/5] fix ci --- .github/workflows/main.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index ca1a00c..522d089 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -36,7 +36,7 @@ jobs: run: cargo fmt --all -- --check - name: Run clippy - run: cargo clippy --workspace --all-features + run: cargo clippy check_style_python: name: Check file formatting and style (Python) @@ -74,7 +74,7 @@ jobs: - uses: Swatinem/rust-cache@v2 - name: Run tests - run: cargo test --workspace --all-features + run: cargo test test_python: name: Tests (Python)