diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index ca1a00c..522d089 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -36,7 +36,7 @@ jobs: run: cargo fmt --all -- --check - name: Run clippy - run: cargo clippy --workspace --all-features + run: cargo clippy check_style_python: name: Check file formatting and style (Python) @@ -74,7 +74,7 @@ jobs: - uses: Swatinem/rust-cache@v2 - name: Run tests - run: cargo test --workspace --all-features + run: cargo test test_python: name: Tests (Python) diff --git a/Cargo.lock b/Cargo.lock index 602f351..381432c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,12 +28,23 @@ checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" name = "cranberry" version = "0.1.3" dependencies = [ + "cudarc", + "once_cell", "pyo3", "rand", "rand_distr", "thiserror", ] +[[package]] +name = "cudarc" +version = "0.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72ba848ae5c6f3cb36e71eab5f268763e3fabcabe3f7bc683e16f7fa3d46281e" +dependencies = [ + "libloading", +] + [[package]] name = "find-msvc-tools" version = "0.1.1" @@ -69,6 +80,16 @@ version = "0.2.175" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" +[[package]] +name = "libloading" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" +dependencies = [ + "cfg-if", + "windows-targets", +] + [[package]] name = "libm" version = "0.2.15" @@ -307,6 +328,77 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + +[[package]] +name = "windows-targets" +version = "0.53.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + 
+[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "zerocopy" version = "0.8.27" diff --git a/Cargo.toml b/Cargo.toml index ba51973..47e1ad1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,10 +12,13 @@ pyo3 = { version = "0.24.1", features = ["extension-module"] } rand = { version = "0.8.5" } rand_distr = "0.4" thiserror = "1" +cudarc = { version = "0.17", optional = true, features = ["driver", "nvrtc", "cuda-version-from-build-system"] } +once_cell = "1.19" [features] default = ["python"] python = [] +cuda = ["cudarc"] abi3 = ["pyo3/abi3-py37", "generate-import-lib"] generate-import-lib = ["pyo3/generate-import-lib"] diff --git a/README.md b/README.md index 02eef22..5ef93d8 100644 --- a/README.md +++ b/README.md @@ -15,13 +15,14 @@ A small deep learning framework in Rust and Python ## Overview -Cranberry is an educational project exploring how a tensor library, automatic differentiation, and a Rust-backed storage layer fit together. The Python front-end intentionally stays simple while the Rust extension supplies fast contiguous kernels and view manipulation utilities. Everything currently targets CPU and 32-bit floating point tensors. +Cranberry is an educational project exploring how a tensor library, automatic differentiation, and a Rust-backed storage layer fit together. The Python front-end intentionally stays simple while the Rust extension supplies fast contiguous kernels and view manipulation utilities. Everything targets 32-bit floating point tensors and now offers optional CUDA acceleration for pointwise kernels alongside the CPU path. ## Highlights - Python-first `Tensor` API backed by the `StorageView` PyO3 module. - Reverse-mode autograd with topological traversal, supporting gradient tracking through broadcasting and reshape/expand/permute transforms. 
- Contiguous CPU kernels for unary/binary ops plus sum/max reductions, with broadcasting handled in Python. +- Optional CUDA backend for contiguous unary and binary operations when an NVIDIA GPU and toolkit are available. - Basic neural-network building blocks (`nn.Linear`, `nn.ReLU`, `nn.Sequential`) and stochastic gradient descent in `optim.SGD`. - Visualization helpers for autograd graphs (`cranberry.features.visualize`) and an MNIST downloader with caching (`cranberry.features.datasets`). @@ -56,11 +57,12 @@ Cranberry is an educational project exploring how a tensor library, automatic di **Rust Extension** - `StorageView` exposes contiguous tensor storage, reshaping, expanding, permuting, and random fills. - CPU backend implements SIMD-accelerated unary/binary kernels and reduction routines. +- CUDA backend (via `cudarc`) mirrors the contiguous unary/binary kernels when a CUDA device is detected. - Views currently support up to rank-4 tensors; non-contiguous reshape/permute paths are under construction. ## Limitations & Work in Progress -- CPU-only; GPU/Metal backends are stubbed out. +- CUDA backend currently covers only contiguous unary/binary kernels; Metal remains stubbed out. - Autograd requires scalar losses and does not yet handle slicing/indexing/in-place mutations. - Views must be contiguous for most kernels; slicing and advanced indexing are not implemented. - Only `float32` tensors are supported; dtype promotion and mixed precision are future work. @@ -98,6 +100,12 @@ Optional extras: - `pip install -e .[datasets]` for download progress via `tqdm`. - `pip install -e .[all]` to include every extra. +## CUDA backend + +- Requires an NVIDIA driver and CUDA toolkit (NVRTC must be discoverable via `CUDA_HOME`, `CUDA_PATH`, or the default `/usr/local/cuda`). +- No separate `nvcc` build step is needed—the crate compiles its kernels at runtime using NVRTC. 
+- At runtime, pass `device="cuda"` when creating tensors/storage; contiguous unary and binary ops then execute on the GPU. The extension must be built with the `cuda` Cargo feature; without it, or when no CUDA device is present, those ops raise a runtime error instead of falling back to the CPU. + ## Quickstart ```python diff --git a/src/backend/kernels_simd.rs b/src/backend/cpu.rs similarity index 65% rename from src/backend/kernels_simd.rs rename to src/backend/cpu.rs index 439cd3d..e301d46 100644 --- a/src/backend/kernels_simd.rs +++ b/src/backend/cpu.rs @@ -1,14 +1,70 @@ use std::simd::num::SimdFloat; use std::simd::{f32x64, StdFloat}; +use super::{ + require_contiguous, require_same_numel, view_from_storage, Backend, BackendResult, BinaryOp, + UnaryOp, +}; +use crate::core::{storage::StorageInner, view::View}; +use crate::device::Device; + const CHUNK_SIZE: usize = 64; -pub mod unary_ops { +pub struct CpuBackend; + +impl CpuBackend { + fn apply_unary(op: UnaryOp, input: &[f32], output: &mut [f32]) { + unary::apply(op, input, output); + } + + fn apply_binary(op: BinaryOp, lhs: &[f32], rhs: &[f32], output: &mut [f32]) { + binary::apply(op, lhs, rhs, output); + } +} + +impl Backend for CpuBackend { + fn unary(&self, op: UnaryOp, a: &View) -> BackendResult { + require_contiguous(a)?; + let input = a.inner.as_slice(a.offset, a.numel()); + let mut storage = StorageInner::new_full(0.0, a.numel(), Device::Cpu); + { + let output = storage.as_mut_slice(0, a.numel()); + Self::apply_unary(op, input, output); + } + Ok(view_from_storage(storage, &a.shape)) + } + + fn binary(&self, op: BinaryOp, a: &View, b: &View) -> BackendResult { + require_contiguous(a)?; + require_contiguous(b)?; + require_same_numel(a, b)?; + let lhs = a.inner.as_slice(a.offset, a.numel()); + let rhs = b.inner.as_slice(b.offset, b.numel()); + let mut storage = StorageInner::new_full(0.0, a.numel(), Device::Cpu); + { + let output = storage.as_mut_slice(0, a.numel()); + Self::apply_binary(op, lhs, rhs, output); + } + Ok(view_from_storage(storage, &a.shape)) + } +} + +pub(crate) mod unary { use super::*; use
std::ops::Neg; - pub fn neg(a: &[f32], b: &mut [f32]) { - assert!(a.len() == b.len()); + pub fn apply(op: UnaryOp, input: &[f32], output: &mut [f32]) { + assert_eq!(input.len(), output.len()); + match op { + UnaryOp::Neg => neg(input, output), + UnaryOp::Sqrt => sqrt(input, output), + UnaryOp::Exp => exp(input, output), + UnaryOp::Log => log(input, output), + UnaryOp::Relu => relu(input, output), + } + } + + fn neg(a: &[f32], b: &mut [f32]) { let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE); let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE); a_main @@ -21,8 +77,7 @@ pub mod unary_ops { .for_each(|(a, b)| *b = -a); } - pub fn sqrt(a: &[f32], b: &mut [f32]) { - assert!(a.len() == b.len()); + fn sqrt(a: &[f32], b: &mut [f32]) { let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE); let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE); a_main @@ -35,8 +90,7 @@ pub mod unary_ops { .for_each(|(a, b)| *b = a.sqrt()); } - pub fn relu(a: &[f32], b: &mut [f32]) { - assert!(a.len() == b.len()); + fn relu(a: &[f32], b: &mut [f32]) { let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE); let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE); let zero = f32x64::splat(0.0); @@ -50,8 +104,7 @@ pub mod unary_ops { .for_each(|(a, b)| *b = a.max(0.0)); } - pub fn exp(a: &[f32], b: &mut [f32]) { - assert!(a.len() == b.len()); + fn exp(a: &[f32], b: &mut [f32]) { let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE); let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE); a_main @@ -64,8 +117,7 @@ pub mod unary_ops { .for_each(|(a, b)| *b = a.exp()); } - pub fn log(a: &[f32], b: &mut [f32]) { - assert!(a.len() == b.len()); + fn log(a: &[f32], b: &mut [f32]) { let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE); let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE); a_main @@ -79,12 +131,22 @@ pub mod unary_ops { } } -pub mod binary_ops { 
+pub(crate) mod binary { use super::*; use std::ops::{Add, Div, Mul, Sub}; - pub fn add(a: &[f32], b: &[f32], c: &mut [f32]) { - assert!(a.len() == b.len() && b.len() == c.len()); + pub fn apply(op: BinaryOp, lhs: &[f32], rhs: &[f32], output: &mut [f32]) { + assert_eq!(lhs.len(), rhs.len()); + assert_eq!(rhs.len(), output.len()); + match op { + BinaryOp::Add => add(lhs, rhs, output), + BinaryOp::Sub => sub(lhs, rhs, output), + BinaryOp::Mul => mul(lhs, rhs, output), + BinaryOp::Div => div(lhs, rhs, output), + } + } + + fn add(a: &[f32], b: &[f32], c: &mut [f32]) { let main = a.len() - a.len() % CHUNK_SIZE; let (a_main, a_rem) = a.split_at(main); let (b_main, b_rem) = b.split_at(main); @@ -105,8 +167,7 @@ pub mod binary_ops { .for_each(|((a, b), c)| *c = a + b); } - pub fn sub(a: &[f32], b: &[f32], c: &mut [f32]) { - assert!(a.len() == b.len() && b.len() == c.len()); + fn sub(a: &[f32], b: &[f32], c: &mut [f32]) { let main = a.len() - a.len() % CHUNK_SIZE; let (a_main, a_rem) = a.split_at(main); let (b_main, b_rem) = b.split_at(main); @@ -127,8 +188,7 @@ pub mod binary_ops { .for_each(|((a, b), c)| *c = a - b); } - pub fn mul(a: &[f32], b: &[f32], c: &mut [f32]) { - assert!(a.len() == b.len() && b.len() == c.len()); + fn mul(a: &[f32], b: &[f32], c: &mut [f32]) { let main = a.len() - a.len() % CHUNK_SIZE; let (a_main, a_rem) = a.split_at(main); let (b_main, b_rem) = b.split_at(main); @@ -149,8 +209,7 @@ pub mod binary_ops { .for_each(|((a, b), c)| *c = a * b); } - pub fn div(a: &[f32], b: &[f32], c: &mut [f32]) { - assert!(a.len() == b.len() && b.len() == c.len()); + fn div(a: &[f32], b: &[f32], c: &mut [f32]) { let main = a.len() - a.len() % CHUNK_SIZE; let (a_main, a_rem) = a.split_at(main); let (b_main, b_rem) = b.split_at(main); diff --git a/src/backend/cuda.rs b/src/backend/cuda.rs new file mode 100644 index 0000000..0ae4159 --- /dev/null +++ b/src/backend/cuda.rs @@ -0,0 +1,405 @@ +use std::{fmt, sync::Arc}; + +use cudarc::driver::{self, CudaContext, 
CudaSlice, CudaStream, LaunchConfig}; + +use super::{ + require_contiguous, require_same_numel, view_from_storage, Backend, BackendError, + BackendResult, BinaryOp, UnaryOp, +}; +use crate::core::{storage::StorageInner, view::View}; +use crate::device::Device; + +fn map_driver_err(err: driver::result::DriverError) -> BackendError { + BackendError::Cuda(err.to_string()) +} + +mod kernels { + use super::{map_driver_err, BackendError, BackendResult, BinaryOp, UnaryOp}; + use cudarc::driver::{ + CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream, LaunchConfig, PushKernelArg, + }; + use cudarc::nvrtc; + use std::sync::Arc; + + pub(crate) struct Kernels { + unary: unary::Module, + binary: binary::Module, + } + + impl Kernels { + pub fn compile(ctx: &Arc) -> Result { + // Produce PTX from the inline CUDA source. NVRTC jit-compiles at runtime, + // which keeps the crate self-contained and avoids a separate `nvcc` step. + let ptx = + nvrtc::compile_ptx(SOURCE).map_err(|err| BackendError::Cuda(err.to_string()))?; + let module = ctx.load_module(ptx).map_err(map_driver_err)?; + Ok(Self { + unary: unary::Module::load(&module)?, + binary: binary::Module::load(&module)?, + }) + } + + pub fn launch_unary( + &self, + stream: &Arc, + op: UnaryOp, + input: &CudaSlice, + output: &mut CudaSlice, + len: usize, + cfg: LaunchConfig, + ) -> BackendResult<()> { + self.unary.launch(stream, op, input, output, len, cfg) + } + + pub fn launch_binary( + &self, + stream: &Arc, + op: BinaryOp, + lhs: &CudaSlice, + rhs: &CudaSlice, + output: &mut CudaSlice, + len: usize, + cfg: LaunchConfig, + ) -> BackendResult<()> { + self.binary.launch(stream, op, lhs, rhs, output, len, cfg) + } + } + + mod unary { + use super::*; + + pub struct Module { + neg: CudaFunction, + sqrt: CudaFunction, + exp: CudaFunction, + log: CudaFunction, + relu: CudaFunction, + } + + impl Module { + pub fn load(module: &Arc) -> Result { + Ok(Self { + neg: module.load_function("unary_neg").map_err(map_driver_err)?, + 
sqrt: module.load_function("unary_sqrt").map_err(map_driver_err)?, + exp: module.load_function("unary_exp").map_err(map_driver_err)?, + log: module.load_function("unary_log").map_err(map_driver_err)?, + relu: module.load_function("unary_relu").map_err(map_driver_err)?, + }) + } + + pub fn launch( + &self, + stream: &Arc, + op: UnaryOp, + input: &CudaSlice, + output: &mut CudaSlice, + len: usize, + cfg: LaunchConfig, + ) -> BackendResult<()> { + let func = match op { + UnaryOp::Neg => &self.neg, + UnaryOp::Sqrt => &self.sqrt, + UnaryOp::Exp => &self.exp, + UnaryOp::Log => &self.log, + UnaryOp::Relu => &self.relu, + }; + unsafe { + stream + .launch_builder(func) + .arg(input) + .arg(output) + .arg(&len) + .launch(cfg) + } + .map(|_| ()) + .map_err(map_driver_err) + } + } + } + + mod binary { + use super::*; + + pub struct Module { + add: CudaFunction, + sub: CudaFunction, + mul: CudaFunction, + div: CudaFunction, + } + + impl Module { + pub fn load(module: &Arc) -> Result { + Ok(Self { + add: module.load_function("binary_add").map_err(map_driver_err)?, + sub: module.load_function("binary_sub").map_err(map_driver_err)?, + mul: module.load_function("binary_mul").map_err(map_driver_err)?, + div: module.load_function("binary_div").map_err(map_driver_err)?, + }) + } + + pub fn launch( + &self, + stream: &Arc, + op: BinaryOp, + lhs: &CudaSlice, + rhs: &CudaSlice, + output: &mut CudaSlice, + len: usize, + cfg: LaunchConfig, + ) -> BackendResult<()> { + let func = match op { + BinaryOp::Add => &self.add, + BinaryOp::Sub => &self.sub, + BinaryOp::Mul => &self.mul, + BinaryOp::Div => &self.div, + }; + unsafe { + stream + .launch_builder(func) + .arg(lhs) + .arg(rhs) + .arg(output) + .arg(&len) + .launch(cfg) + } + .map(|_| ()) + .map_err(map_driver_err) + } + } + } + + /// CUDA kernels expressed as C strings compiled at runtime with NVRTC. 
+ const SOURCE: &str = r#" +// NVRTC treats this source as C++, so we pull in the headers that define the +// math intrinsics (`sqrtf`, `expf`, `logf`) and the `size_t` type that we use for +// element counts. +#include +#include + +// Every kernel follows the exact same thread-indexing pattern: each CUDA thread +// owns a unique flat index into the logical tensor. The math is identical for all +// kernels, so we declare a helper macro to make the generated PTX a little smaller +// and, just as importantly for new CUDA developers, to keep the thread math readable. +#define GLOBAL_INDEX() (blockIdx.x * blockDim.x + threadIdx.x) + +extern "C" __global__ void unary_neg(const float* __restrict__ inp, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = -inp[idx]; + } +} + +extern "C" __global__ void unary_sqrt(const float* __restrict__ inp, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = sqrtf(inp[idx]); + } +} + +extern "C" __global__ void unary_exp(const float* __restrict__ inp, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = expf(inp[idx]); + } +} + +extern "C" __global__ void unary_log(const float* __restrict__ inp, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = logf(inp[idx]); + } +} + +extern "C" __global__ void unary_relu(const float* __restrict__ inp, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + const float v = inp[idx]; + out[idx] = v > 0.0f ? 
v : 0.0f; + } +} + +extern "C" __global__ void binary_add(const float* __restrict__ lhs, + const float* __restrict__ rhs, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = lhs[idx] + rhs[idx]; + } +} + +extern "C" __global__ void binary_sub(const float* __restrict__ lhs, + const float* __restrict__ rhs, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = lhs[idx] - rhs[idx]; + } +} + +extern "C" __global__ void binary_mul(const float* __restrict__ lhs, + const float* __restrict__ rhs, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = lhs[idx] * rhs[idx]; + } +} + +extern "C" __global__ void binary_div(const float* __restrict__ lhs, + const float* __restrict__ rhs, + float* __restrict__ out, + size_t n) { + const size_t idx = GLOBAL_INDEX(); + if (idx < n) { + out[idx] = lhs[idx] / rhs[idx]; + } +} +"#; +} + +/// Holds the default stream of the device-0 CUDA context and the compiled kernel handles shared by GPU ops. +pub struct CudaBackend { + stream: Arc, + kernels: kernels::Kernels, +} + +impl fmt::Debug for CudaBackend { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CudaBackend") + .field("device", &Device::Cuda) + .finish() + } +} + +impl CudaBackend { + /// Creates a new CUDA backend by initializing the context, default stream, and kernels. + pub fn new() -> BackendResult { + Self::init() + } + + fn init() -> Result { + let device_count = CudaContext::device_count().map_err(map_driver_err)?; + if device_count <= 0 { + return Err(BackendError::CudaUnavailable( + "no CUDA devices detected".to_string(), + )); + } + + // Acquire the primary context for device 0 and keep its default stream alive.
+ // We intentionally pick device 0 for now; extending this to support multiple + // GPUs just means plumbing an ordinal through the Python API and reusing the + // exact same initialization procedure. + let ctx = CudaContext::new(0).map_err(map_driver_err)?; + let stream = ctx.default_stream(); + + // Load the compiled kernels into the new context. + let kernels = kernels::Kernels::compile(&ctx)?; + + Ok(CudaBackend { stream, kernels }) + } +} + +impl CudaBackend { + fn copy_to_device(&self, view: &View) -> BackendResult> { + // The storage currently lives in host memory. Before we can launch a + // kernel we need to stage that data on the device. `memcpy_stod` + // allocates the device buffer for us and schedules the host-to-device + // transfer onto the stream we cached during initialization. + let slice = view.inner.as_slice(view.offset, view.numel()); + self.stream.memcpy_stod(slice).map_err(map_driver_err) + } + + fn copy_from_device(&self, slice: CudaSlice, shape: &[usize]) -> BackendResult { + // The kernel launch is asynchronous with respect to the host. We force + // completion here so that when we copy the results back we are guaranteed + // to see the writes that the GPU just produced. + self.stream.synchronize().map_err(map_driver_err)?; + // Bring the freshly computed buffer back to the host. The helper allocates + // a `Vec` of the right size and performs a device-to-host copy. + let host = self.stream.memcpy_dtov(&slice).map_err(map_driver_err)?; + // Wrap the vector in our existing `StorageInner` abstraction so the rest + // of the codebase can keep treating the buffer like any other contiguous + // allocation. We still tag it as `Device::Cuda` to preserve the logical + // device placement that the caller requested. 
+ let inner = StorageInner::from_vec(host, Device::Cuda); + Ok(view_from_storage(inner, shape)) + } +} + +impl Backend for CudaBackend { + fn unary(&self, op: UnaryOp, a: &View) -> BackendResult { + require_contiguous(a)?; + + // Upload the source tensor to device memory so the kernel can consume it. + // We do this afresh on every invocation; a production implementation would + // want to keep allocations around, but this keeps the control flow easy to + // follow while still exercising the GPU cores. + let device_in = self.copy_to_device(a)?; + let mut device_out = self + .stream + .alloc_zeros::(a.numel()) + .map_err(map_driver_err)?; + // `alloc_zeros` both reserves device memory and memsets it to 0. The + // zero-fill is not strictly required for our kernels, but it gives every + // output element a well-defined value should a kernel ever skip a write + // (which would indicate a bug elsewhere). + + // Compute the launch geometry for this kernel dispatch. The helper chooses + // a 1024-thread block size and enough blocks to touch every element. + // NOTE(review): the `as u32` cast truncates silently if `numel()` exceeds `u32::MAX`. + let cfg = LaunchConfig::for_num_elems(a.numel() as u32); + let len = a.numel(); + + self.kernels + .launch_unary(&self.stream, op, &device_in, &mut device_out, len, cfg)?; + + self.copy_from_device(device_out, &a.shape) + } + + fn binary(&self, op: BinaryOp, a: &View, b: &View) -> BackendResult { + require_contiguous(a)?; + require_contiguous(b)?; + require_same_numel(a, b)?; + + // Mirror the upload path we used for the unary kernels, but stage both + // operands on the device. + let device_a = self.copy_to_device(a)?; + let device_b = self.copy_to_device(b)?; + let mut device_out = self + .stream + .alloc_zeros::(a.numel()) + .map_err(map_driver_err)?; + // A fresh output buffer is allocated for every call; as in the unary path, + // nothing is cached or shared between launches, and each call dispatches + // exactly one kernel.
+ + // Again, a single helper computes the dispatch geometry. + let cfg = LaunchConfig::for_num_elems(a.numel() as u32); + let len = a.numel(); + + self.kernels.launch_binary( + &self.stream, + op, + &device_a, + &device_b, + &mut device_out, + len, + cfg, + )?; + + self.copy_from_device(device_out, &a.shape) + } +} diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 18ad865..b85fa2e 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -1,6 +1,10 @@ use thiserror::Error; +use std::sync::Arc; + +use crate::core::storage::StorageInner; use crate::core::view::{contiguous_strides, View}; +use crate::device::Device; #[derive(Debug, Clone, Copy)] pub enum UnaryOp { @@ -19,12 +23,19 @@ pub enum BinaryOp { Div, } -#[derive(Debug, Error)] +#[derive(Debug, Error, Clone)] pub enum BackendError { #[error("operation requires contiguous views")] NotContiguous, #[error("shape mismatch")] ShapeMismatch, + #[error("backend for device {0:?} is not implemented")] + UnsupportedDevice(Device), + #[cfg(feature = "cuda")] + #[error("cuda runtime error: {0}")] + Cuda(String), + #[error("cuda device unavailable: {0}")] + CudaUnavailable(String), } pub type BackendResult = Result; @@ -34,66 +45,37 @@ pub trait Backend { fn binary(&self, op: BinaryOp, a: &View, b: &View) -> BackendResult; } -// === CPU backend (contiguous-only for now) === - -use std::sync::Arc; - -use crate::core::storage::StorageInner; -use crate::device::Device; - -mod kernels_simd; - -pub struct CpuBackend; +fn require_contiguous(view: &View) -> BackendResult<()> { + if view.is_contiguous() { + Ok(()) + } else { + Err(BackendError::NotContiguous) + } +} -impl Backend for CpuBackend { - fn unary(&self, op: UnaryOp, a: &View) -> BackendResult { - if !a.is_contiguous() { - return Err(BackendError::NotContiguous); - } - let a_slice = a.inner.as_slice(a.offset, a.numel()); - let mut out_inner = StorageInner::new_full(0.0, a.numel(), Device::Cpu); - { - let out_slice = out_inner.as_mut_slice(0, a.numel()); - 
match op { - UnaryOp::Neg => kernels_simd::unary_ops::neg(a_slice, out_slice), - UnaryOp::Sqrt => kernels_simd::unary_ops::sqrt(a_slice, out_slice), - UnaryOp::Exp => kernels_simd::unary_ops::exp(a_slice, out_slice), - UnaryOp::Log => kernels_simd::unary_ops::log(a_slice, out_slice), - UnaryOp::Relu => kernels_simd::unary_ops::relu(a_slice, out_slice), - } - } - Ok(View { - inner: Arc::new(out_inner), - offset: 0, - shape: a.shape.clone(), - strides: contiguous_strides(&a.shape), - }) +fn require_same_numel(a: &View, b: &View) -> BackendResult<()> { + if a.numel() == b.numel() { + Ok(()) + } else { + Err(BackendError::ShapeMismatch) } +} - fn binary(&self, op: BinaryOp, a: &View, b: &View) -> BackendResult { - if !a.is_contiguous() || !b.is_contiguous() { - return Err(BackendError::NotContiguous); - } - if a.numel() != b.numel() { - return Err(BackendError::ShapeMismatch); - } - let a_slice = a.inner.as_slice(a.offset, a.numel()); - let b_slice = b.inner.as_slice(b.offset, b.numel()); - let mut out_inner = StorageInner::new_full(0.0, a.numel(), Device::Cpu); - { - let out_slice = out_inner.as_mut_slice(0, a.numel()); - match op { - BinaryOp::Add => kernels_simd::binary_ops::add(a_slice, b_slice, out_slice), - BinaryOp::Sub => kernels_simd::binary_ops::sub(a_slice, b_slice, out_slice), - BinaryOp::Mul => kernels_simd::binary_ops::mul(a_slice, b_slice, out_slice), - BinaryOp::Div => kernels_simd::binary_ops::div(a_slice, b_slice, out_slice), - } - } - Ok(View { - inner: Arc::new(out_inner), - offset: 0, - shape: a.shape.clone(), - strides: contiguous_strides(&a.shape), - }) +fn view_from_storage(inner: StorageInner, shape: &[usize]) -> View { + View { + inner: Arc::new(inner), + offset: 0, + shape: shape.to_vec(), + strides: contiguous_strides(shape), } } + +mod cpu; +#[cfg(feature = "cuda")] +mod cuda; + +pub mod registry; + +pub use cpu::CpuBackend; +#[cfg(feature = "cuda")] +pub use cuda::CudaBackend; diff --git a/src/backend/registry.rs b/src/backend/registry.rs 
new file mode 100644 index 0000000..7e5f9b9 --- /dev/null +++ b/src/backend/registry.rs @@ -0,0 +1,46 @@ +#[cfg(feature = "cuda")] +use super::CudaBackend; +use super::{Backend, BackendError, BackendResult, CpuBackend}; +use crate::device::Device; +#[cfg(feature = "cuda")] +use once_cell::sync::OnceCell; + +static CPU_BACKEND: CpuBackend = CpuBackend; +#[cfg(feature = "cuda")] +static CUDA_BACKEND: OnceCell> = OnceCell::new(); + +#[cfg(feature = "cuda")] +fn cuda_backend() -> BackendResult<&'static CudaBackend> { + CUDA_BACKEND + .get_or_init(|| CudaBackend::new()) + .as_ref() + .map_err(Clone::clone) +} + +pub fn cpu() -> &'static CpuBackend { + &CPU_BACKEND +} + +#[cfg(feature = "cuda")] +pub fn cuda() -> BackendResult<&'static CudaBackend> { + cuda_backend() +} + +pub fn get(device: Device) -> BackendResult<&'static dyn Backend> { + match device { + Device::Cpu => Ok(cpu() as &dyn Backend), + Device::Cuda => { + #[cfg(feature = "cuda")] + { + cuda().map(|backend| backend as &dyn Backend) + } + #[cfg(not(feature = "cuda"))] + { + Err(BackendError::CudaUnavailable( + "crate compiled without the `cuda` feature".to_string(), + )) + } + } + other => Err(BackendError::UnsupportedDevice(other)), + } +} diff --git a/src/py/mod.rs b/src/py/mod.rs index 8550f94..a936939 100644 --- a/src/py/mod.rs +++ b/src/py/mod.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use pyo3::prelude::*; -use crate::backend::{Backend, BinaryOp, CpuBackend, UnaryOp}; +use crate::backend::{registry, Backend, BackendError, BinaryOp, UnaryOp}; use crate::core::{ reduce::{self, AxisError}, storage::StorageInner, @@ -21,6 +21,22 @@ fn axis_error_to_py(err: AxisError) -> PyErr { } } +fn backend_err_to_py(err: BackendError) -> PyErr { + pyo3::exceptions::PyRuntimeError::new_err(err.to_string()) +} + +fn backend_for_device(device: Device) -> Result<&'static dyn Backend, PyErr> { + match registry::get(device) { + Ok(backend) => Ok(backend), + Err(BackendError::UnsupportedDevice(device)) => { + 
Err(pyo3::exceptions::PyNotImplementedError::new_err( + BackendError::UnsupportedDevice(device).to_string(), + )) + } + Err(err) => Err(backend_err_to_py(err)), + } +} + #[cfg(test)] mod tests { use super::*; @@ -278,143 +294,80 @@ impl StorageView { } pub fn neg(&self) -> PyResult { - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .unary(UnaryOp::Neg, &self.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "unary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .unary(UnaryOp::Neg, &self.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } pub fn sqrt(&self) -> PyResult { - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .unary(UnaryOp::Sqrt, &self.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "unary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .unary(UnaryOp::Sqrt, &self.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } pub fn relu(&self) -> PyResult { - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .unary(UnaryOp::Relu, &self.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "unary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .unary(UnaryOp::Relu, &self.view) + 
.map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } pub fn exp(&self) -> PyResult { - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .unary(UnaryOp::Exp, &self.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "unary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .unary(UnaryOp::Exp, &self.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } pub fn log(&self) -> PyResult { - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .unary(UnaryOp::Log, &self.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "unary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .unary(UnaryOp::Log, &self.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } pub fn add(&self, other: &StorageView) -> PyResult { if self.view.inner.device() != other.view.inner.device() { return Err(pyo3::exceptions::PyValueError::new_err("device mismatch")); } - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .binary(BinaryOp::Add, &self.view, &other.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "binary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .binary(BinaryOp::Add, &self.view, &other.view) + 
.map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } pub fn sub(&self, other: &StorageView) -> PyResult { if self.view.inner.device() != other.view.inner.device() { return Err(pyo3::exceptions::PyValueError::new_err("device mismatch")); } - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .binary(BinaryOp::Sub, &self.view, &other.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "binary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .binary(BinaryOp::Sub, &self.view, &other.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } pub fn mul(&self, other: &StorageView) -> PyResult { if self.view.inner.device() != other.view.inner.device() { return Err(pyo3::exceptions::PyValueError::new_err("device mismatch")); } - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .binary(BinaryOp::Mul, &self.view, &other.view) - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "binary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .binary(BinaryOp::Mul, &self.view, &other.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } pub fn div(&self, other: &StorageView) -> PyResult { if self.view.inner.device() != other.view.inner.device() { return Err(pyo3::exceptions::PyValueError::new_err("device mismatch")); } - match self.view.inner.device() { - crate::device::Device::Cpu => { - let backend = CpuBackend; - let out = backend - .binary(BinaryOp::Div, &self.view, &other.view) - .map_err(|e| 
pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; - Ok(StorageView { view: out }) - } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "binary ops not implemented for this device", - )), - } + let backend = backend_for_device(self.view.inner.device())?; + let out = backend + .binary(BinaryOp::Div, &self.view, &other.view) + .map_err(backend_err_to_py)?; + Ok(StorageView { view: out }) } #[pyo3(signature = (dim=None, keepdim=false))] @@ -426,8 +379,8 @@ impl StorageView { let reduced = reduce::reduce_sum(&self.view, axis, keepdim); Ok(StorageView { view: reduced }) } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "sum not implemented for this device", + other => Err(pyo3::exceptions::PyNotImplementedError::new_err( + BackendError::UnsupportedDevice(other).to_string(), )), } } @@ -441,8 +394,8 @@ impl StorageView { let reduced = reduce::reduce_max(&self.view, axis, keepdim); Ok(StorageView { view: reduced }) } - _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( - "max not implemented for this device", + other => Err(pyo3::exceptions::PyNotImplementedError::new_err( + BackendError::UnsupportedDevice(other).to_string(), )), } } diff --git a/src/tests/backend.rs b/src/tests/backend.rs index 7d3200d..56237b5 100644 --- a/src/tests/backend.rs +++ b/src/tests/backend.rs @@ -1,6 +1,8 @@ use std::sync::Arc; -use crate::backend::{Backend, BinaryOp, CpuBackend, UnaryOp}; +#[cfg(feature = "cuda")] +use crate::backend::CudaBackend; +use crate::backend::{registry, Backend, BackendError, BinaryOp, UnaryOp}; use crate::core::{storage::StorageInner, view::View}; use crate::device::Device; @@ -9,10 +11,88 @@ fn vec_view(v: Vec) -> View { View::from_inner_1d(inner) } +#[cfg(feature = "cuda")] +fn cuda_view(v: Vec) -> View { + let inner = Arc::new(StorageInner::from_vec(v, Device::Cuda)); + View::from_inner_1d(inner) +} + +#[cfg(feature = "cuda")] +fn cuda_backend_or_skip(test_name: &str) -> Option<&'static CudaBackend> { + match 
registry::cuda() { + Ok(be) => Some(be), + Err(err) => { + eprintln!("skipping {test_name}: {err}"); + None + } + } +} + +#[cfg(feature = "cuda")] +fn assert_close(got: f32, expected: f32, tol: f32, idx: usize) { + let diff = (got - expected).abs(); + assert!( + diff <= tol, + "value mismatch at index {idx}: got {got}, expected {expected}, diff {diff} > tol {tol}" + ); +} + +#[cfg(feature = "cuda")] +fn assert_vec_close(actual: &[f32], expected: &[f32], tol: f32) { + assert_eq!(actual.len(), expected.len(), "length mismatch"); + for (idx, (got, exp)) in actual.iter().zip(expected.iter()).enumerate() { + assert_close(*got, *exp, tol, idx); + } +} + +#[cfg(feature = "cuda")] +fn run_cuda_unary_case(name: &str, op: UnaryOp, input: &[f32], reference: F) +where + F: Fn(f32) -> f32, +{ + let backend = match cuda_backend_or_skip(name) { + Some(be) => be, + None => return, + }; + + let view = cuda_view(input.to_vec()); + let out = backend + .unary(op, &view) + .unwrap_or_else(|err| panic!("{} failed: {}", name, err)); + let actual = out.inner.as_slice(out.offset, out.numel()); + let expected: Vec = input.iter().copied().map(reference).collect(); + assert_vec_close(actual, expected.as_slice(), 1e-5); +} + +#[cfg(feature = "cuda")] +fn run_cuda_binary_case(name: &str, op: BinaryOp, lhs: &[f32], rhs: &[f32], reference: F) +where + F: Fn(f32, f32) -> f32, +{ + let backend = match cuda_backend_or_skip(name) { + Some(be) => be, + None => return, + }; + + let a = cuda_view(lhs.to_vec()); + let b = cuda_view(rhs.to_vec()); + let out = backend + .binary(op, &a, &b) + .unwrap_or_else(|err| panic!("{} failed: {}", name, err)); + let actual = out.inner.as_slice(out.offset, out.numel()); + let expected: Vec = lhs + .iter() + .copied() + .zip(rhs.iter().copied()) + .map(|(x, y)| reference(x, y)) + .collect(); + assert_vec_close(actual, expected.as_slice(), 1e-5); +} + #[test] fn unary_neg_matches_scalar() { let a = vec_view(vec![1.0, -2.0, 3.0, -4.5, 0.25]); - let be = CpuBackend; + let 
be = registry::cpu(); let out = be.unary(UnaryOp::Neg, &a).unwrap(); let actual = out.inner.as_slice(out.offset, out.numel()); let expect: Vec = a @@ -27,7 +107,7 @@ fn unary_neg_matches_scalar() { #[test] fn unary_sqrt_matches_scalar() { let a = vec_view(vec![0.0, 0.25, 1.0, 4.0, 9.0, 16.0, 2.25]); - let be = CpuBackend; + let be = registry::cpu(); let out = be.unary(UnaryOp::Sqrt, &a).unwrap(); let actual = out.inner.as_slice(out.offset, out.numel()); let expect: Vec = a @@ -44,7 +124,7 @@ fn unary_sqrt_matches_scalar() { #[test] fn unary_relu_matches_scalar() { let a = vec_view(vec![-3.0, -0.5, 0.0, 0.5, 2.0, 5.0]); - let be = CpuBackend; + let be = registry::cpu(); let out = be.unary(UnaryOp::Relu, &a).unwrap(); let actual = out.inner.as_slice(out.offset, out.numel()); let expect: Vec = a @@ -62,7 +142,7 @@ fn unary_relu_matches_scalar() { fn binary_add_matches_scalar() { let a = vec_view((0..130).map(|i| i as f32 * 0.5).collect()); // span SIMD + remainder let b = vec_view((0..130).map(|i| i as f32 * -0.25).collect()); - let be = CpuBackend; + let be = registry::cpu(); let out = be.binary(BinaryOp::Add, &a, &b).unwrap(); let actual = out.inner.as_slice(out.offset, out.numel()); let a_s = a.inner.as_slice(a.offset, a.numel()); @@ -76,7 +156,7 @@ fn binary_add_matches_scalar() { fn binary_shape_mismatch_error() { let a = vec_view(vec![1.0, 2.0, 3.0]); let b = vec_view(vec![4.0, 5.0]); - let be = CpuBackend; + let be = registry::cpu(); let err = be.binary(BinaryOp::Add, &a, &b).unwrap_err(); assert!(matches!(err, crate::backend::BackendError::ShapeMismatch)); } @@ -86,7 +166,7 @@ fn unary_not_contiguous_error() { let a = vec_view((0..12).map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); let a_nc = a.permute(&[1, 0]); assert!(!a_nc.is_contiguous()); - let be = CpuBackend; + let be = registry::cpu(); let err = be.unary(UnaryOp::Neg, &a_nc).unwrap_err(); assert!(matches!(err, crate::backend::BackendError::NotContiguous)); } @@ -96,7 +176,172 @@ fn 
binary_not_contiguous_error() { let a = vec_view((0..12).map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); let b = vec_view((0..12).rev().map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); let a_nc = a.permute(&[1, 0]); - let be = CpuBackend; + let be = registry::cpu(); let err = be.binary(BinaryOp::Add, &a_nc, &b).unwrap_err(); assert!(matches!(err, crate::backend::BackendError::NotContiguous)); } + +#[test] +fn registry_get_cpu_backend() { + let backend = registry::get(Device::Cpu).unwrap(); + let view = vec_view(vec![1.0, -2.0, 3.0]); + let out = backend.unary(UnaryOp::Neg, &view).unwrap(); + assert_eq!( + out.inner.as_slice(out.offset, out.numel()), + &[-1.0, 2.0, -3.0] + ); +} + +#[test] +fn registry_reports_unsupported_device() { + let err = match registry::get(Device::Metal) { + Ok(_) => panic!("expected registry lookup to fail for Metal"), + Err(err) => err, + }; + assert!(matches!( + err, + BackendError::UnsupportedDevice(Device::Metal) + )); +} + +#[cfg(not(feature = "cuda"))] +#[test] +fn registry_reports_cuda_unavailable_when_feature_disabled() { + let err = match registry::get(Device::Cuda) { + Ok(_) => panic!("expected CUDA backend lookup to fail without feature"), + Err(err) => err, + }; + assert!(matches!( + err, + BackendError::CudaUnavailable(message) if message.contains("`cuda` feature") + )); +} + +#[cfg(feature = "cuda")] +#[test] +fn cuda_unary_neg_matches_cpu_when_device_available() { + run_cuda_unary_case( + "cuda_unary_neg", + UnaryOp::Neg, + &[1.0, -3.0, 0.5, 7.25], + |x| -x, + ); +} + +#[cfg(feature = "cuda")] +#[test] +fn cuda_unary_sqrt_matches_cpu_when_device_available() { + run_cuda_unary_case( + "cuda_unary_sqrt", + UnaryOp::Sqrt, + &[0.0, 0.25, 1.0, 4.0, 9.0], + |x| x.sqrt(), + ); +} + +#[cfg(feature = "cuda")] +#[test] +fn cuda_unary_exp_matches_cpu_when_device_available() { + run_cuda_unary_case( + "cuda_unary_exp", + UnaryOp::Exp, + &[1.0, -2.5, 3.25, -4.75, 0.0], + |x| x.exp(), + ); +} + +#[cfg(feature = "cuda")] 
+#[test] +fn cuda_unary_log_matches_cpu_when_device_available() { + run_cuda_unary_case( + "cuda_unary_log", + UnaryOp::Log, + &[0.25, 1.0, 2.5, 10.0], + |x| x.ln(), + ); +} + +#[cfg(feature = "cuda")] +#[test] +fn cuda_unary_relu_matches_cpu_when_device_available() { + run_cuda_unary_case( + "cuda_unary_relu", + UnaryOp::Relu, + &[-3.0, -0.5, 0.0, 0.5, 2.0, 5.0], + |x| x.max(0.0), + ); +} + +#[cfg(feature = "cuda")] +#[test] +fn cuda_binary_add_matches_cpu_when_device_available() { + let lhs: Vec = (0..64).map(|i| i as f32 * 0.5).collect(); + let rhs: Vec = (0..64).map(|i| i as f32 * -0.25).collect(); + run_cuda_binary_case("cuda_binary_add", BinaryOp::Add, &lhs, &rhs, |x, y| x + y); +} + +#[cfg(feature = "cuda")] +#[test] +fn cuda_binary_sub_matches_cpu_when_device_available() { + let lhs: Vec = (0..64).map(|i| i as f32 * 0.5).collect(); + let rhs: Vec = (0..64).map(|i| i as f32 * -0.25).collect(); + run_cuda_binary_case("cuda_binary_sub", BinaryOp::Sub, &lhs, &rhs, |x, y| x - y); +} + +#[cfg(feature = "cuda")] +#[test] +fn cuda_binary_mul_matches_cpu_when_device_available() { + let lhs: Vec = (0..64).map(|i| i as f32 * 0.5).collect(); + let rhs: Vec = (0..64).map(|i| (i as f32 + 1.0) * 0.1).collect(); + run_cuda_binary_case("cuda_binary_mul", BinaryOp::Mul, &lhs, &rhs, |x, y| x * y); +} + +#[cfg(feature = "cuda")] +#[test] +fn cuda_binary_div_matches_cpu_when_device_available() { + let lhs: Vec = (1..65).map(|i| i as f32).collect(); + let rhs: Vec = (1..65).map(|i| (i as f32) * 0.5 + 1.0).collect(); + run_cuda_binary_case("cuda_binary_div", BinaryOp::Div, &lhs, &rhs, |x, y| x / y); +} + +#[cfg(feature = "cuda")] +#[test] +fn cuda_unary_not_contiguous_error_when_device_available() { + let backend = match cuda_backend_or_skip("cuda_unary_not_contiguous_error") { + Some(be) => be, + None => return, + }; + let view = cuda_view((0..12).map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); + let non_contig = view.permute(&[1, 0]); + let err = 
backend.unary(UnaryOp::Neg, &non_contig).unwrap_err(); + assert!(matches!(err, crate::backend::BackendError::NotContiguous)); +} + +#[cfg(feature = "cuda")] +#[test] +fn cuda_binary_not_contiguous_error_when_device_available() { + let backend = match cuda_backend_or_skip("cuda_binary_not_contiguous_error") { + Some(be) => be, + None => return, + }; + let view_a = cuda_view((0..12).map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); + let view_b = cuda_view((0..12).rev().map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); + let non_contig = view_a.permute(&[1, 0]); + let err = backend + .binary(BinaryOp::Add, &non_contig, &view_b) + .unwrap_err(); + assert!(matches!(err, crate::backend::BackendError::NotContiguous)); +} + +#[cfg(feature = "cuda")] +#[test] +fn cuda_binary_shape_mismatch_error_when_device_available() { + let backend = match cuda_backend_or_skip("cuda_binary_shape_mismatch_error") { + Some(be) => be, + None => return, + }; + let a = cuda_view(vec![1.0, 2.0, 3.0]); + let b = cuda_view(vec![4.0, 5.0]); + let err = backend.binary(BinaryOp::Add, &a, &b).unwrap_err(); + assert!(matches!(err, crate::backend::BackendError::ShapeMismatch)); +}