diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 32ac2b0..27e6519 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -28,7 +28,7 @@ jobs: - uses: dtolnay/rust-toolchain@nightly with: - toolchain: nightly-2024-08-15 + toolchain: nightly components: rustfmt, clippy - uses: Swatinem/rust-cache@v2 @@ -68,7 +68,7 @@ jobs: - uses: dtolnay/rust-toolchain@nightly with: - toolchain: nightly-2024-08-15 + toolchain: nightly - uses: Swatinem/rust-cache@v2 - name: Run tests diff --git a/Cargo.lock b/Cargo.lock index 263eefe..5c95aa0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,33 +1,28 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "autocfg" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" - -[[package]] -name = "byteorder" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "cc" -version = "1.1.20" +version = "1.2.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45bcde016d64c21da4be18b655631e5ab6d3107607e71a73a9f53eb48aae23fb" +checksum = "5252b3d2648e5eedbc1a6f501e3c795e07025c1e93bbf8bbdd6eef7f447a6d54" dependencies = [ + "find-msvc-tools", "shlex", ] [[package]] name = "cfg-if" -version = "1.0.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" [[package]] name = "cranberry" @@ -35,13 +30,20 @@ version = "0.1.3" dependencies = [ "pyo3", "rand", + "thiserror", ] +[[package]] +name = "find-msvc-tools" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fd99930f64d146689264c637b5af2f0233a933bef0d8570e2526bf9e083192d" + [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "libc", @@ -56,15 +58,15 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "indoc" -version = "2.0.5" +version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" [[package]] name = "libc" -version = "0.2.158" +version = "0.2.175" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" +checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" [[package]] name = "memoffset" @@ -77,39 +79,39 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.19.0" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" [[package]] name = "portable-atomic" -version = "1.7.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" [[package]] name = "ppv-lite86" -version = "0.2.20" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ "zerocopy", ] [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" dependencies = [ "unicode-ident", ] [[package]] name = "pyo3" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17da310086b068fbdcefbba30aeb3721d5bb9af8db4987d6735b2183ca567229" +checksum = "e5203598f366b11a02b13aa20cab591229ff0a89fd121a308a5df751d5fc9219" dependencies = [ "cfg-if", "indoc", @@ -125,9 +127,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e27165889bd793000a098bb966adc4300c312497ea25cf7a690a9f0ac5aa5fc1" +checksum = "99636d423fa2ca130fa5acde3059308006d46f98caac629418e53f7ebb1e9999" dependencies = [ "once_cell", "python3-dll-a", @@ -136,9 +138,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05280526e1dbf6b420062f3ef228b78c0c54ba94e157f5cb724a609d0f2faabc" +checksum = "78f9cf92ba9c409279bc3305b5409d90db2d2c22392d443a87df3a1adad59e33" dependencies = [ "libc", "pyo3-build-config", @@ -146,9 +148,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c3ce5686aa4d3f63359a5100c62a127c9f15e8398e5fdeb5deef1fed5cd5f44" +checksum = "0b999cb1a6ce21f9a6b147dcf1be9ffedf02e0043aec74dc390f3007047cecd9" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -158,9 +160,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4cf6faa0cbfb0ed08e89beb8103ae9724eb4750e3a78084ba4017cbe94f3855" +checksum = "822ece1c7e1012745607d5cf0bcb2874769f0f7cb34c4cde03b9358eb9ef911a" dependencies = [ "heck", "proc-macro2", @@ -180,9 +182,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.37" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" dependencies = [ "proc-macro2", ] @@ -225,9 +227,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "syn" -version = "2.0.77" +version = "2.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" dependencies = [ "proc-macro2", "quote", @@ -240,39 +242,58 @@ version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df7f62577c25e07834649fc3b39fafdc597c0a3527dc1c60129201ccfcbaa50c" +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicode-ident" -version = "1.0.13" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d" [[package]] name = "unindent" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" [[package]] name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "zerocopy" -version = "0.7.35" +version = "0.8.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" dependencies = [ - "byteorder", "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.35" +version = "0.8.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 2d07bd2..ada02ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,11 +5,12 @@ edition = "2021" [lib] name = "cranberry" -crate-type = ["cdylib"] +crate-type = ["cdylib", "rlib"] [dependencies] pyo3 = { version = "0.24.1", features = ["extension-module"] } rand = { version = "0.8.5" } +thiserror = "1" [features] abi3 = ["pyo3/abi3-py37", "generate-import-lib"] diff --git a/cranberry/__init__.py b/cranberry/__init__.py index 0285539..ed6625b 100644 --- a/cranberry/__init__.py +++ b/cranberry/__init__.py @@ -1,2 +1,2 @@ from cranberry.tensor import Tensor as Tensor -from .cranberry import StoragePtr as StoragePtr +from .cranberry import StorageView as StorageView diff --git a/cranberry/cranberry.pyi b/cranberry/cranberry.pyi index 965b688..1a87693 100644 --- a/cranberry/cranberry.pyi +++ b/cranberry/cranberry.pyi @@ -1,29 +1,22 @@ from typing import Union import numpy as np -class StoragePtr: +class StorageView: @staticmethod - def full(value: float, size: int, device: str) -> StoragePtr: ... + def full(value: float, size: int, device: str) -> StorageView: ... @staticmethod - def from_vec(vec: Union[list[float], np.ndarray], device: str) -> StoragePtr: ... - @staticmethod - def neg(a: StoragePtr, b: StoragePtr, idx_a: int, idx_b: int, size: int): ... - @staticmethod - def sqrt(a: StoragePtr, b: StoragePtr, idx_a: int, idx_b: int, size: int): ... - @staticmethod - def exp(a: StoragePtr, b: StoragePtr, idx_a: int, idx_b: int, size: int): ... - @staticmethod - def log(a: StoragePtr, b: StoragePtr, idx_a: int, idx_b: int, size: int): ... - @staticmethod - def add(a: StoragePtr, b: StoragePtr, c: StoragePtr, idx_a: int, idx_b: int, idx_c: int, size: int): ... - @staticmethod - def sub(a: StoragePtr, b: StoragePtr, c: StoragePtr, idx_a: int, idx_b: int, idx_c: int, size: int): ... - @staticmethod - def mul(a: StoragePtr, b: StoragePtr, c: StoragePtr, idx_a: int, idx_b: int, idx_c: int, size: int): ... - @staticmethod - def div(a: StoragePtr, b: StoragePtr, c: StoragePtr, idx_a: int, idx_b: int, idx_c: int, size: int): ... - @staticmethod - def sum(a: StoragePtr, b: StoragePtr, idx_a: int, idx_b: int, size: int): ... - @staticmethod - def max(a: StoragePtr, b: StoragePtr, idx_a: int, idx_b: int, size: int): ... + def from_vec(vec: Union[list[float], np.ndarray], device: str) -> StorageView: ... + def len(self) -> int: ... def to_vec(self) -> list[float]: ... + def slice(self, offset: int, size: int) -> StorageView: ... + def reshape(self, shape: list[int]) -> StorageView: ... + def expand(self, shape: list[int]) -> StorageView: ... + def permute(self, dims: list[int]) -> StorageView: ... + def neg(self) -> StorageView: ... + def sqrt(self) -> StorageView: ... + def exp(self) -> StorageView: ... + def log(self) -> StorageView: ... + def add(self, other: StorageView) -> StorageView: ... + def sub(self, other: StorageView) -> StorageView: ... + def mul(self, other: StorageView) -> StorageView: ... + def div(self, other: StorageView) -> StorageView: ... diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 15b2405..5d56faf 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2024-08-15" +channel = "nightly" diff --git a/src/backend/kernels_simd.rs b/src/backend/kernels_simd.rs new file mode 100644 index 0000000..c63a07e --- /dev/null +++ b/src/backend/kernels_simd.rs @@ -0,0 +1,157 @@ +use std::simd::{f32x64, StdFloat}; + +const CHUNK_SIZE: usize = 64; + +pub mod unary_ops { + use super::*; + use std::ops::Neg; + + pub fn neg(a: &[f32], b: &mut [f32]) { + assert!(a.len() == b.len()); + let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE); + let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE); + a_main + .chunks_exact(CHUNK_SIZE) + .zip(b_main.chunks_exact_mut(CHUNK_SIZE)) + .for_each(|(a, b)| f32x64::from_slice(a).neg().copy_to_slice(b)); + a_rem + .iter() + .zip(b_rem.iter_mut()) + .for_each(|(a, b)| *b = -a); + } + + pub fn sqrt(a: &[f32], b: &mut [f32]) { + assert!(a.len() == b.len()); + let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE); + let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE); + a_main + .chunks_exact(CHUNK_SIZE) + .zip(b_main.chunks_exact_mut(CHUNK_SIZE)) + .for_each(|(a, b)| f32x64::from_slice(a).sqrt().copy_to_slice(b)); + a_rem + .iter() + .zip(b_rem.iter_mut()) + .for_each(|(a, b)| *b = a.sqrt()); + } + + pub fn exp(a: &[f32], b: &mut [f32]) { + assert!(a.len() == b.len()); + let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE); + let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE); + a_main + .chunks_exact(CHUNK_SIZE) + .zip(b_main.chunks_exact_mut(CHUNK_SIZE)) + .for_each(|(a, b)| f32x64::from_slice(a).exp().copy_to_slice(b)); + a_rem + .iter() + .zip(b_rem.iter_mut()) + .for_each(|(a, b)| *b = a.exp()); + } + + pub fn log(a: &[f32], b: &mut [f32]) { + assert!(a.len() == b.len()); + let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE); + let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE); + a_main + .chunks_exact(CHUNK_SIZE) + .zip(b_main.chunks_exact_mut(CHUNK_SIZE)) + .for_each(|(a, b)| f32x64::from_slice(a).ln().copy_to_slice(b)); + a_rem + .iter() + .zip(b_rem.iter_mut()) + .for_each(|(a, b)| *b = a.ln()); + } +} + +pub mod binary_ops { + use super::*; + use std::ops::{Add, Div, Mul, Sub}; + + pub fn add(a: &[f32], b: &[f32], c: &mut [f32]) { + assert!(a.len() == b.len() && b.len() == c.len()); + let main = a.len() - a.len() % CHUNK_SIZE; + let (a_main, a_rem) = a.split_at(main); + let (b_main, b_rem) = b.split_at(main); + let (c_main, c_rem) = c.split_at_mut(main); + a_main + .chunks_exact(CHUNK_SIZE) + .zip(b_main.chunks_exact(CHUNK_SIZE)) + .zip(c_main.chunks_exact_mut(CHUNK_SIZE)) + .for_each(|((a, b), c)| { + f32x64::from_slice(a) + .add(f32x64::from_slice(b)) + .copy_to_slice(c) + }); + a_rem + .iter() + .zip(b_rem.iter()) + .zip(c_rem.iter_mut()) + .for_each(|((a, b), c)| *c = a + b); + } + + pub fn sub(a: &[f32], b: &[f32], c: &mut [f32]) { + assert!(a.len() == b.len() && b.len() == c.len()); + let main = a.len() - a.len() % CHUNK_SIZE; + let (a_main, a_rem) = a.split_at(main); + let (b_main, b_rem) = b.split_at(main); + let (c_main, c_rem) = c.split_at_mut(main); + a_main + .chunks_exact(CHUNK_SIZE) + .zip(b_main.chunks_exact(CHUNK_SIZE)) + .zip(c_main.chunks_exact_mut(CHUNK_SIZE)) + .for_each(|((a, b), c)| { + f32x64::from_slice(a) + .sub(f32x64::from_slice(b)) + .copy_to_slice(c) + }); + a_rem + .iter() + .zip(b_rem.iter()) + .zip(c_rem.iter_mut()) + .for_each(|((a, b), c)| *c = a - b); + } + + pub fn mul(a: &[f32], b: &[f32], c: &mut [f32]) { + assert!(a.len() == b.len() && b.len() == c.len()); + let main = a.len() - a.len() % CHUNK_SIZE; + let (a_main, a_rem) = a.split_at(main); + let (b_main, b_rem) = b.split_at(main); + let (c_main, c_rem) = c.split_at_mut(main); + a_main + .chunks_exact(CHUNK_SIZE) + .zip(b_main.chunks_exact(CHUNK_SIZE)) + .zip(c_main.chunks_exact_mut(CHUNK_SIZE)) + .for_each(|((a, b), c)| { + f32x64::from_slice(a) + .mul(f32x64::from_slice(b)) + .copy_to_slice(c) + }); + a_rem + .iter() + .zip(b_rem.iter()) + .zip(c_rem.iter_mut()) + .for_each(|((a, b), c)| *c = a * b); + } + + pub fn div(a: &[f32], b: &[f32], c: &mut [f32]) { + assert!(a.len() == b.len() && b.len() == c.len()); + let main = a.len() - a.len() % CHUNK_SIZE; + let (a_main, a_rem) = a.split_at(main); + let (b_main, b_rem) = b.split_at(main); + let (c_main, c_rem) = c.split_at_mut(main); + a_main + .chunks_exact(CHUNK_SIZE) + .zip(b_main.chunks_exact(CHUNK_SIZE)) + .zip(c_main.chunks_exact_mut(CHUNK_SIZE)) + .for_each(|((a, b), c)| { + f32x64::from_slice(a) + .div(f32x64::from_slice(b)) + .copy_to_slice(c) + }); + a_rem + .iter() + .zip(b_rem.iter()) + .zip(c_rem.iter_mut()) + .for_each(|((a, b), c)| *c = a / b); + } +} diff --git a/src/backend/mod.rs b/src/backend/mod.rs new file mode 100644 index 0000000..4eb19ee --- /dev/null +++ b/src/backend/mod.rs @@ -0,0 +1,111 @@ +use thiserror::Error; + +use crate::core::view::View; + +#[derive(Debug, Clone, Copy)] +pub enum UnaryOp { + Neg, + Sqrt, + Exp, + Log, +} + +#[derive(Debug, Clone, Copy)] +pub enum BinaryOp { + Add, + Sub, + Mul, + Div, +} + +#[derive(Debug, Error)] +pub enum BackendError { + #[error("operation requires contiguous views")] + NotContiguous, + #[error("shape mismatch")] + ShapeMismatch, +} + +pub type BackendResult = Result; + +pub trait Backend { + fn unary(&self, op: UnaryOp, a: &View) -> BackendResult; + fn binary(&self, op: BinaryOp, a: &View, b: &View) -> BackendResult; +} + +// === CPU backend (contiguous-only for now) === + +use std::sync::Arc; + +use crate::core::storage::StorageInner; +use crate::device::Device; + +mod kernels_simd; + +pub struct CpuBackend; + +impl Backend for CpuBackend { + fn unary(&self, op: UnaryOp, a: &View) -> BackendResult { + if !a.is_contiguous() { + return Err(BackendError::NotContiguous); + } + let a_slice = a.inner.as_slice(a.offset, a.numel()); + let mut out_inner = StorageInner::new_full(0.0, a.numel(), Device::Cpu); + { + let out_slice = out_inner.as_mut_slice(0, a.numel()); + match op { + UnaryOp::Neg => kernels_simd::unary_ops::neg(a_slice, out_slice), + UnaryOp::Sqrt => kernels_simd::unary_ops::sqrt(a_slice, out_slice), + UnaryOp::Exp => kernels_simd::unary_ops::exp(a_slice, out_slice), + UnaryOp::Log => kernels_simd::unary_ops::log(a_slice, out_slice), + } + } + let mut stride = 1isize; + let mut strides = vec![0isize; a.shape.len()]; + for (i, dim) in a.shape.iter().rev().enumerate() { + let idx = a.shape.len() - 1 - i; + strides[idx] = stride; + stride *= *dim as isize; + } + Ok(View { + inner: Arc::new(out_inner), + offset: 0, + shape: a.shape.clone(), + strides, + }) + } + + fn binary(&self, op: BinaryOp, a: &View, b: &View) -> BackendResult { + if !a.is_contiguous() || !b.is_contiguous() { + return Err(BackendError::NotContiguous); + } + if a.numel() != b.numel() { + return Err(BackendError::ShapeMismatch); + } + let a_slice = a.inner.as_slice(a.offset, a.numel()); + let b_slice = b.inner.as_slice(b.offset, b.numel()); + let mut out_inner = StorageInner::new_full(0.0, a.numel(), Device::Cpu); + { + let out_slice = out_inner.as_mut_slice(0, a.numel()); + match op { + BinaryOp::Add => kernels_simd::binary_ops::add(a_slice, b_slice, out_slice), + BinaryOp::Sub => kernels_simd::binary_ops::sub(a_slice, b_slice, out_slice), + BinaryOp::Mul => kernels_simd::binary_ops::mul(a_slice, b_slice, out_slice), + BinaryOp::Div => kernels_simd::binary_ops::div(a_slice, b_slice, out_slice), + } + } + let mut stride = 1isize; + let mut strides = vec![0isize; a.shape.len()]; + for (i, dim) in a.shape.iter().rev().enumerate() { + let idx = a.shape.len() - 1 - i; + strides[idx] = stride; + stride *= *dim as isize; + } + Ok(View { + inner: Arc::new(out_inner), + offset: 0, + shape: a.shape.clone(), + strides, + }) + } +} diff --git a/src/core/mod.rs b/src/core/mod.rs new file mode 100644 index 0000000..c65f10b --- /dev/null +++ b/src/core/mod.rs @@ -0,0 +1,2 @@ +pub mod storage; +pub mod view; diff --git a/src/core/storage.rs b/src/core/storage.rs new file mode 100644 index 0000000..b7b1fb4 --- /dev/null +++ b/src/core/storage.rs @@ -0,0 +1,41 @@ +use crate::device::Device; + +/// Owns the raw allocation backing one or more `View`s. +#[derive(Debug)] +pub struct StorageInner { + pub(crate) data: Vec, + pub(crate) device: Device, +} + +impl StorageInner { + pub fn new_full(value: f32, size: usize, device: Device) -> Self { + Self { + data: vec![value; size], + device, + } + } + + pub fn from_vec(vec: Vec, device: Device) -> Self { + Self { data: vec, device } + } + + pub fn len(&self) -> usize { + self.data.len() + } + + pub fn device(&self) -> Device { + self.device.clone() + } + + pub fn as_slice(&self, offset: usize, size: usize) -> &[f32] { + &self.data[offset..offset + size] + } + + pub fn as_mut_slice(&mut self, offset: usize, size: usize) -> &mut [f32] { + &mut self.data[offset..offset + size] + } + + pub fn to_vec(&self, offset: usize, size: usize) -> Vec { + self.as_slice(offset, size).to_vec() + } +} diff --git a/src/core/view.rs b/src/core/view.rs new file mode 100644 index 0000000..e2a19ab --- /dev/null +++ b/src/core/view.rs @@ -0,0 +1,123 @@ +use std::sync::Arc; + +use super::storage::StorageInner; + +#[derive(Clone, Debug)] +pub struct View { + pub(crate) inner: Arc, + pub(crate) offset: usize, + pub(crate) shape: Vec, + pub(crate) strides: Vec, +} + +impl View { + pub fn from_inner_1d(inner: Arc) -> Self { + let len = inner.len(); + Self { + inner, + offset: 0, + shape: vec![len], + strides: vec![1], + } + } + + pub fn numel(&self) -> usize { + self.shape.iter().product() + } + + pub fn is_contiguous(&self) -> bool { + let mut expected: isize = 1; + for (&dim, &stride) in self.shape.iter().rev().zip(self.strides.iter().rev()) { + if dim == 1 { + continue; + } + if stride != expected { + return false; + } + expected *= dim as isize; + } + true + } + + pub fn slice_1d(&self, offset: usize, len: usize) -> Self { + debug_assert_eq!(self.shape.len(), 1, "slice_1d expects a 1D view"); + debug_assert!(offset + len <= self.numel()); + let mut out = self.clone(); + out.offset = self.offset + offset; + out.shape = vec![len]; + out.strides = vec![1]; + out + } + + /// Reshape a contiguous view to `new_shape` (row-major). Number of elements must match. + pub fn reshape_contiguous(&self, new_shape: &[usize]) -> Self { + assert!(self.is_contiguous(), "reshape requires contiguous view"); + assert_eq!( + self.numel(), + new_shape.iter().copied().product::(), + "reshape size mismatch" + ); + let mut stride = 1isize; + let mut strides = vec![0isize; new_shape.len()]; + for (i, dim) in new_shape.iter().rev().enumerate() { + let idx = new_shape.len() - 1 - i; + strides[idx] = stride; + stride *= *dim as isize; + } + Self { + inner: self.inner.clone(), + offset: self.offset, + shape: new_shape.to_vec(), + strides, + } + } + + /// Permute axes by `dims`. Metadata-only; strides/shape are reordered. + pub fn permute(&self, dims: &[usize]) -> Self { + assert_eq!(dims.len(), self.shape.len()); + let mut seen = vec![false; dims.len()]; + for &d in dims { + assert!(d < dims.len() && !seen[d]); + seen[d] = true; + } + let shape = dims.iter().map(|&i| self.shape[i]).collect::>(); + let strides = dims.iter().map(|&i| self.strides[i]).collect::>(); + Self { + inner: self.inner.clone(), + offset: self.offset, + shape, + strides, + } + } + + /// Broadcast this view to `new_shape` by setting stride 0 on expanded axes. + pub fn expand(&self, new_shape: &[usize]) -> Self { + let mut s_shape = self.shape.clone(); + while s_shape.len() < new_shape.len() { + s_shape.insert(0, 1); + } + let mut s_strides = { + let mut ss = self.strides.clone(); + while ss.len() < new_shape.len() { + ss.insert(0, 0); + } + ss + }; + for i in 0..new_shape.len() { + let need = new_shape[i]; + let have = s_shape[i]; + assert!(have == need || have == 1, "expand incompatible at dim {i}"); + if have == 1 && need > 1 { + s_strides[i] = 0; + } + } + Self { + inner: self.inner.clone(), + offset: self.offset, + shape: new_shape.to_vec(), + strides: s_strides, + } + } + + // TODO: general strided iteration, overlap/alias checks. +} diff --git a/src/lib.rs b/src/lib.rs index 7e3616c..7891b73 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,17 @@ #![feature(portable_simd)] -#![feature(array_chunks)] +mod backend; +mod core; mod device; -mod storage; -mod storage_ptr; +mod py; use pyo3::prelude::*; +#[cfg(test)] +mod tests; + #[pyo3::pymodule] fn cranberry(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/src/py/mod.rs b/src/py/mod.rs new file mode 100644 index 0000000..9f554ab --- /dev/null +++ b/src/py/mod.rs @@ -0,0 +1,200 @@ +use std::sync::Arc; + +use pyo3::prelude::*; + +use crate::backend::{Backend, BinaryOp, CpuBackend, UnaryOp}; +use crate::core::{storage::StorageInner, view::View}; +use crate::device::Device; + +#[pyo3::pyclass] +#[derive(Clone)] +pub struct StorageView { + view: View, +} + +#[pyo3::pymethods] +impl StorageView { + #[staticmethod] + pub fn from_vec(vec: Vec, device: &str) -> PyResult { + let inner = StorageInner::from_vec(vec, Device::from_str(device)); + Ok(Self { + view: View::from_inner_1d(Arc::new(inner)), + }) + } + + #[staticmethod] + pub fn full(value: f32, size: usize, device: &str) -> PyResult { + let inner = StorageInner::new_full(value, size, Device::from_str(device)); + Ok(Self { + view: View::from_inner_1d(Arc::new(inner)), + }) + } + + pub fn len(&self) -> usize { + self.view.numel() + } + + pub fn to_vec(&self) -> PyResult> { + if !self.view.is_contiguous() { + return Err(pyo3::exceptions::PyValueError::new_err( + "to_vec only supports contiguous views for now", + )); + } + Ok(self.view.inner.to_vec(self.view.offset, self.view.numel())) + } + + pub fn slice(&self, offset: usize, size: usize) -> PyResult { + if self.view.shape.len() != 1 { + return Err(pyo3::exceptions::PyValueError::new_err( + "slice(offset, size) is only available for 1D views", + )); + } + Ok(StorageView { + view: self.view.slice_1d(offset, size), + }) + } + + pub fn reshape(&self, shape: Vec) -> PyResult { + Ok(StorageView { + view: self.view.reshape_contiguous(&shape), + }) + } + + pub fn expand(&self, shape: Vec) -> PyResult { + Ok(StorageView { + view: self.view.expand(&shape), + }) + } + + pub fn permute(&self, dims: Vec) -> PyResult { + Ok(StorageView { + view: self.view.permute(&dims), + }) + } + + pub fn neg(&self) -> PyResult { + match self.view.inner.device() { + crate::device::Device::Cpu => { + let backend = CpuBackend; + let out = backend + .unary(UnaryOp::Neg, &self.view) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + Ok(StorageView { view: out }) + } + _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( + "unary ops not implemented for this device", + )), + } + } + pub fn sqrt(&self) -> PyResult { + match self.view.inner.device() { + crate::device::Device::Cpu => { + let backend = CpuBackend; + let out = backend + .unary(UnaryOp::Sqrt, &self.view) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + Ok(StorageView { view: out }) + } + _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( + "unary ops not implemented for this device", + )), + } + } + pub fn exp(&self) -> PyResult { + match self.view.inner.device() { + crate::device::Device::Cpu => { + let backend = CpuBackend; + let out = backend + .unary(UnaryOp::Exp, &self.view) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + Ok(StorageView { view: out }) + } + _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( + "unary ops not implemented for this device", + )), + } + } + pub fn log(&self) -> PyResult { + match self.view.inner.device() { + crate::device::Device::Cpu => { + let backend = CpuBackend; + let out = backend + .unary(UnaryOp::Log, &self.view) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + Ok(StorageView { view: out }) + } + _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( + "unary ops not implemented for this device", + )), + } + } + + pub fn add(&self, other: &StorageView) -> PyResult { + if self.view.inner.device() != other.view.inner.device() { + return Err(pyo3::exceptions::PyValueError::new_err("device mismatch")); + } + match self.view.inner.device() { + crate::device::Device::Cpu => { + let backend = CpuBackend; + let out = backend + .binary(BinaryOp::Add, &self.view, &other.view) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + Ok(StorageView { view: out }) + } + _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( + "binary ops not implemented for this device", + )), + } + } + pub fn sub(&self, other: &StorageView) -> PyResult { + if self.view.inner.device() != other.view.inner.device() { + return Err(pyo3::exceptions::PyValueError::new_err("device mismatch")); + } + match self.view.inner.device() { + crate::device::Device::Cpu => { + let backend = CpuBackend; + let out = backend + .binary(BinaryOp::Sub, &self.view, &other.view) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + Ok(StorageView { view: out }) + } + _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( + "binary ops not implemented for this device", + )), + } + } + pub fn mul(&self, other: &StorageView) -> PyResult { + if self.view.inner.device() != other.view.inner.device() { + return Err(pyo3::exceptions::PyValueError::new_err("device mismatch")); + } + match self.view.inner.device() { + crate::device::Device::Cpu => { + let backend = CpuBackend; + let out = backend + .binary(BinaryOp::Mul, &self.view, &other.view) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + Ok(StorageView { view: out }) + } + _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( + "binary ops not implemented for this device", + )), + } + } + pub fn div(&self, other: &StorageView) -> PyResult { + if self.view.inner.device() != other.view.inner.device() { + return Err(pyo3::exceptions::PyValueError::new_err("device mismatch")); + } + match self.view.inner.device() { + crate::device::Device::Cpu => { + let backend = CpuBackend; + let out = backend + .binary(BinaryOp::Div, &self.view, &other.view) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + Ok(StorageView { view: out }) + } + _ => Err(pyo3::exceptions::PyNotImplementedError::new_err( + "binary ops not implemented for this device", + )), + } + } +} diff --git a/src/storage/cpu_backend.rs b/src/storage/cpu_backend.rs deleted file mode 100644 index 0f099eb..0000000 --- a/src/storage/cpu_backend.rs +++ /dev/null @@ -1,208 +0,0 @@ -use std::simd::{f32x64, StdFloat}; - -const CHUNK_SIZE: usize = 64; - -pub mod unary_ops { - use std::ops::Neg; - - use super::*; - #[inline(always)] - pub fn neg(a: &[f32], b: &mut [f32]) { - assert!(a.len() == b.len()); - - a.array_chunks::() - .map(|a| f32x64::from_array(*a)) - .zip(b.array_chunks_mut::()) - .for_each(|(a, b)| { - a.neg().copy_to_slice(b); - }); - - let remain = a.len() - a.len() % CHUNK_SIZE; - a[remain..].iter().zip(&mut b[remain..]).for_each(|(a, b)| { - *b = -a; - }); - } - #[inline(always)] - pub fn sqrt(a: &[f32], b: &mut [f32]) { - assert!(a.len() == b.len()); - - a.array_chunks::() - .map(|a| f32x64::from_array(*a)) - .zip(b.array_chunks_mut::()) - .for_each(|(a, b)| { - a.sqrt().copy_to_slice(b); - }); - - let remain = a.len() - a.len() % CHUNK_SIZE; - a[remain..].iter().zip(&mut b[remain..]).for_each(|(a, b)| { - *b = a.sqrt(); - }); - } - #[inline(always)] - pub fn exp(a: &[f32], b: &mut [f32]) { - assert!(a.len() == b.len()); - - a.array_chunks::() - .map(|a| f32x64::from_array(*a)) - .zip(b.array_chunks_mut::()) - .for_each(|(a, b)| { - a.exp().copy_to_slice(b); - }); - - let remain = a.len() - a.len() % CHUNK_SIZE; - a[remain..].iter().zip(&mut b[remain..]).for_each(|(a, b)| { - *b = a.exp(); - }); - } - #[inline(always)] - pub fn log(a: &[f32], b: &mut [f32]) { - assert!(a.len() == b.len()); - - a.array_chunks::() - .map(|a| f32x64::from_array(*a)) - .zip(b.array_chunks_mut::()) - .for_each(|(a, b)| { - a.ln().copy_to_slice(b); - }); - - let remain = a.len() - a.len() % CHUNK_SIZE; - a[remain..].iter().zip(&mut b[remain..]).for_each(|(a, b)| { - *b = a.ln(); - }); - } -} - -pub mod binary_ops { - use std::ops::{Add, Div, Mul, Sub}; - - use super::*; - #[inline(always)] - pub fn add(a: &[f32], b: &[f32], c: &mut [f32]) { - assert!(a.len() == b.len() && b.len() == c.len()); - - a.array_chunks::() - .map(|a| f32x64::from_slice(a)) - .zip( - b.array_chunks::() - .map(|b| f32x64::from_slice(b)), - ) - .zip(c.array_chunks_mut::()) - .for_each(|((a, b), c)| { - a.add(b).copy_to_slice(c); - }); - - let remain = a.len() - a.len() % CHUNK_SIZE; - a[remain..] - .iter() - .zip(&b[remain..]) - .zip(&mut c[remain..]) - .for_each(|((a, b), c)| { - *c = a + b; - }); - } - #[inline(always)] - pub fn sub(a: &[f32], b: &[f32], c: &mut [f32]) { - assert!(a.len() == b.len() && b.len() == c.len()); - - a.array_chunks::() - .map(|a| f32x64::from_slice(a)) - .zip( - b.array_chunks::() - .map(|b| f32x64::from_slice(b)), - ) - .zip(c.array_chunks_mut::()) - .for_each(|((a, b), c)| { - a.sub(b).copy_to_slice(c); - }); - - let remain = a.len() - a.len() % CHUNK_SIZE; - a[remain..] - .iter() - .zip(&b[remain..]) - .zip(&mut c[remain..]) - .for_each(|((a, b), c)| { - *c = a - b; - }); - } - #[inline(always)] - pub fn mul(a: &[f32], b: &[f32], c: &mut [f32]) { - assert!(a.len() == b.len() && b.len() == c.len()); - - a.array_chunks::() - .map(|a| f32x64::from_slice(a)) - .zip( - b.array_chunks::() - .map(|b| f32x64::from_slice(b)), - ) - .zip(c.array_chunks_mut::()) - .for_each(|((a, b), c)| { - a.mul(b).copy_to_slice(c); - }); - - let remain = a.len() - a.len() % CHUNK_SIZE; - a[remain..] - .iter() - .zip(&b[remain..]) - .zip(&mut c[remain..]) - .for_each(|((a, b), c)| { - *c = a * b; - }); - } - #[inline(always)] - pub fn div(a: &[f32], b: &[f32], c: &mut [f32]) { - assert!(a.len() == b.len() && b.len() == c.len()); - - a.array_chunks::() - .map(|a| f32x64::from_slice(a)) - .zip( - b.array_chunks::() - .map(|b| f32x64::from_slice(b)), - ) - .zip(c.array_chunks_mut::()) - .for_each(|((a, b), c)| { - a.div(b).copy_to_slice(c); - }); - - let remain = a.len() - a.len() % CHUNK_SIZE; - a[remain..] - .iter() - .zip(&b[remain..]) - .zip(&mut c[remain..]) - .for_each(|((a, b), c)| { - *c = a / b; - }); - } -} - -pub mod reduce_ops { - use super::*; - use std::{ops::AddAssign, simd::num::SimdFloat}; - #[inline(always)] - pub fn sum(a: &[f32], b: &mut [f32]) { - assert!(b.len() == 1); - - let mut acc = f32x64::splat(0.0); - a.array_chunks::() - .map(|a| f32x64::from_array(*a)) - .for_each(|a| { - acc.add_assign(a); - }); - - let remain = a.len() - a.len() % CHUNK_SIZE; - b[0] = acc.reduce_sum() + a[remain..].iter().sum::(); - } - #[inline(always)] - pub fn max(a: &[f32], b: &mut [f32]) { - assert!(b.len() == 1); - - let mut acc = f32x64::splat(f32::NEG_INFINITY); - a.array_chunks::() - .map(|a| f32x64::from_array(*a)) - .for_each(|a| { - acc = acc.simd_max(a); - }); - - let remain = a.len() - a.len() % CHUNK_SIZE; - b[0] = a[remain..].iter().copied().fold(acc.reduce_max(), f32::max); - } -} diff --git a/src/storage/mod.rs b/src/storage/mod.rs deleted file mode 100644 index 3f8e0b7..0000000 --- a/src/storage/mod.rs +++ /dev/null @@ -1,217 +0,0 @@ -use crate::device::Device; -mod cpu_backend; - -#[cfg(test)] -mod tests; - -#[derive(Clone, PartialEq)] -pub struct Storage { - data: Vec, - data_size: usize, - device: Device, - // ref_count: i32, -} - -impl Storage { - pub fn new(value: f32, size: usize, device: &str) -> Self { - Storage { - data: vec![value; size], - data_size: size, - device: Device::from_str(device), - // ref_count: 1, - } - } - pub fn from_vec(vec: Vec, device: &str) -> Self { - Storage { - data_size: vec.len(), - data: vec, - device: Device::from_str(device), - // ref_count: 1, - } - } - pub fn to_vec(&self) -> Vec { - self.data.clone() - } - // #[inline(always)] - // pub fn incref(&mut self) { - // assert!(0 < self.ref_count); - // self.ref_count += 1; - // } - // #[inline(always)] - // pub fn decref(&mut self) { - // assert!(0 < self.ref_count); - // self.ref_count -= 1; - // if self.ref_count == 0 { - // // You might wonder why we manually drop the memory here, - // // instead of letting the Rust compiler handle it. - // // The reason is that this is the library code binding to the Python interpreter. - // // The Rust compiler does not know when the Python interpreter will release the memory. - // // Therefore, we need to manually drop the memory when the reference count is zero. - // self.data.clear(); - // self.data.shrink_to_fit(); - // } - // } - pub fn get_items(&self, idx: usize, size: usize) -> &[f32] { - assert!(0 < size); - assert!(idx + size <= self.data_size); - self.data[idx..idx + size].as_ref() - } - pub fn get_items_mut(&mut self, idx: usize, size: usize) -> &mut [f32] { - assert!(0 < size); - assert!(idx + size <= self.data_size); - self.data[idx..idx + size].as_mut() - } -} - -impl Storage { - #[inline(always)] - pub fn neg(a: &Storage, b: &mut Storage, idx_a: usize, idx_b: usize, size: usize) { - assert!(a.device == b.device); - match a.device { - Device::Cpu => { - cpu_backend::unary_ops::neg(a.get_items(idx_a, size), b.get_items_mut(idx_b, size)) - } - Device::Metal => todo!(), - Device::Cuda => todo!(), - } - } - #[inline(always)] - pub fn sqrt(a: &Storage, b: &mut Storage, idx_a: usize, idx_b: usize, size: usize) { - assert!(a.device == b.device); - match a.device { - Device::Cpu => { - cpu_backend::unary_ops::sqrt(a.get_items(idx_a, size), b.get_items_mut(idx_b, size)) - } - Device::Metal => todo!(), - Device::Cuda => todo!(), - } - } - #[inline(always)] - pub fn exp(a: &Storage, b: &mut Storage, idx_a: usize, idx_b: usize, size: usize) { - assert!(a.device == b.device); - match a.device { - Device::Cpu => { - cpu_backend::unary_ops::exp(a.get_items(idx_a, size), b.get_items_mut(idx_b, size)) - } - Device::Metal => todo!(), - Device::Cuda => todo!(), - } - } - #[inline(always)] - pub fn log(a: &Storage, b: &mut Storage, idx_a: usize, idx_b: usize, size: usize) { - assert!(a.device == b.device); - match a.device { - Device::Cpu => { - cpu_backend::unary_ops::log(a.get_items(idx_a, size), b.get_items_mut(idx_b, size)) - } - Device::Metal => todo!(), - Device::Cuda => todo!(), - } - } - #[inline(always)] - pub fn add( - a: &Storage, - b: &Storage, - c: &mut Storage, - idx_a: usize, - idx_b: usize, - idx_c: usize, - size: usize, - ) { - assert!(a.device == b.device && b.device == c.device); - match a.device { - Device::Cpu => cpu_backend::binary_ops::add( - a.get_items(idx_a, size), - b.get_items(idx_b, size), - c.get_items_mut(idx_c, size), - ), - Device::Metal => todo!(), - Device::Cuda => todo!(), - } - } - #[inline(always)] - pub fn sub( - a: &Storage, - b: &Storage, - c: &mut Storage, - idx_a: usize, - idx_b: usize, - idx_c: usize, - size: usize, - ) { - assert!(a.device == b.device && b.device == c.device); - match a.device { - Device::Cpu => cpu_backend::binary_ops::sub( - a.get_items(idx_a, size), - b.get_items(idx_b, size), - c.get_items_mut(idx_c, size), - ), - Device::Metal => todo!(), - Device::Cuda => todo!(), - } - } - #[inline(always)] - pub fn mul( - a: &Storage, - b: &Storage, - c: &mut Storage, - idx_a: usize, - idx_b: usize, - idx_c: usize, - size: usize, - ) { - assert!(a.device == b.device && b.device == c.device); - match a.device { - Device::Cpu => cpu_backend::binary_ops::mul( - a.get_items(idx_a, size), - b.get_items(idx_b, size), - c.get_items_mut(idx_c, size), - ), - Device::Metal => todo!(), - Device::Cuda => todo!(), - } - } - #[inline(always)] - pub fn div( - a: &Storage, - b: &Storage, - c: &mut Storage, - idx_a: usize, - idx_b: usize, - idx_c: usize, - size: usize, - ) { - assert!(a.device == b.device && b.device == c.device); - match a.device { - Device::Cpu => cpu_backend::binary_ops::div( - a.get_items(idx_a, size), - b.get_items(idx_b, size), - c.get_items_mut(idx_c, size), - ), - Device::Metal => todo!(), - Device::Cuda => todo!(), - } - } - #[inline(always)] - pub fn sum(a: &Storage, b: &mut Storage, idx_a: usize, idx_b: usize, size: usize) { - assert!(a.device == b.device); - match a.device { - Device::Cpu => { - cpu_backend::reduce_ops::sum(a.get_items(idx_a, size), b.get_items_mut(idx_b, 1)) - } - Device::Metal => todo!(), - Device::Cuda => todo!(), - } - } - #[inline(always)] - pub fn max(a: &Storage, b: &mut Storage, idx_a: usize, idx_b: usize, size: usize) { - assert!(a.device == b.device); - match a.device { - Device::Cpu => { - cpu_backend::reduce_ops::max(a.get_items(idx_a, size), b.get_items_mut(idx_b, 1)) - } - Device::Metal => todo!(), - Device::Cuda => todo!(), - } - } -} diff --git a/src/storage/tests.rs b/src/storage/tests.rs deleted file mode 100644 index e6afae8..0000000 --- a/src/storage/tests.rs +++ /dev/null @@ -1,242 +0,0 @@ -use crate::storage::Storage; -use rand::random; - -const DEVICE: &str = "cpu"; -const MAX_VEC_SIZE: usize = 4000; -const DEFAULT_TEST_COUNT: usize = 10; - -#[test] -fn test_storage_neg() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let mut y = vec![0.0; vec_size]; - for i in 0..size { - y[idx_2 + i] = -x[idx_1 + i]; - } - - let a = Storage::from_vec(x.clone(), DEVICE); - let mut b = Storage::new(0.0, vec_size, DEVICE); - Storage::neg(&a, &mut b, idx_1, idx_2, size); - - assert!(y.as_slice() == b.get_items(0, vec_size)); - } -} -#[test] -fn test_storage_sqrt() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let mut y = vec![0.0; vec_size]; - for i in 0..size { - y[idx_2 + i] = x[idx_1 + i].sqrt(); - } - - let a = Storage::from_vec(x.clone(), DEVICE); - let mut b = Storage::new(0.0, vec_size, DEVICE); - Storage::sqrt(&a, &mut b, idx_1, idx_2, size); - - assert!(y.as_slice() == b.get_items(0, vec_size)); - } -} -#[test] -fn test_storage_exp() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let mut y = vec![0.0; vec_size]; - for i in 0..size { - y[idx_2 + i] = x[idx_1 + i].exp(); - } - - let a = Storage::from_vec(x.clone(), DEVICE); - let mut b = Storage::new(0.0, vec_size, DEVICE); - Storage::exp(&a, &mut b, idx_1, idx_2, size); - - assert!(y.as_slice() == b.get_items(0, vec_size)); - } -} -#[test] -fn test_storage_log() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let mut y = vec![0.0; vec_size]; - for i in 0..size { - y[idx_2 + i] = x[idx_1 + i].ln(); - } - - let a = Storage::from_vec(x.clone(), DEVICE); - let mut b = Storage::new(0.0, vec_size, DEVICE); - Storage::log(&a, &mut b, idx_1, idx_2, size); - - assert!(y.as_slice() == b.get_items(0, vec_size)); - } -} -#[test] -fn test_storage_add() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let idx_3 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2).max(idx_3)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let y = (0..vec_size).map(|_| random::()).collect::>(); - let mut z = vec![0.0; vec_size]; - for i in 0..size { - z[idx_3 + i] = x[idx_1 + i] + y[idx_2 + i]; - } - - let a = Storage::from_vec(x.clone(), DEVICE); - let b = Storage::from_vec(y.clone(), DEVICE); - let mut c = Storage::new(0.0, vec_size, DEVICE); - Storage::add(&a, &b, &mut c, idx_1, idx_2, idx_3, size); - - assert!(y.as_slice() == b.get_items(0, vec_size)); - } -} -#[test] -fn test_storage_sub() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let idx_3 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2).max(idx_3)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let y = (0..vec_size).map(|_| random::()).collect::>(); - let mut z = vec![0.0; vec_size]; - for i in 0..size { - z[idx_3 + i] = x[idx_1 + i] - y[idx_2 + i]; - } - - let a = Storage::from_vec(x.clone(), DEVICE); - let b = Storage::from_vec(y.clone(), DEVICE); - let mut c = Storage::new(0.0, vec_size, DEVICE); - Storage::sub(&a, &b, &mut c, idx_1, idx_2, idx_3, size); - - assert!(y.as_slice() == b.get_items(0, vec_size)); - } -} -#[test] -fn test_storage_mul() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let idx_3 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2).max(idx_3)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let y = (0..vec_size).map(|_| random::()).collect::>(); - let mut z = vec![0.0; vec_size]; - for i in 0..size { - z[idx_3 + i] = x[idx_1 + i] * y[idx_2 + i]; - } - - let a = Storage::from_vec(x.clone(), DEVICE); - let b = Storage::from_vec(y.clone(), DEVICE); - let mut c = Storage::new(0.0, vec_size, DEVICE); - Storage::mul(&a, &b, &mut c, idx_1, idx_2, idx_3, size); - - assert!(y.as_slice() == b.get_items(0, vec_size)); - } -} -#[test] -fn test_storage_div() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let idx_3 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2).max(idx_3)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let y = (0..vec_size).map(|_| random::()).collect::>(); - let mut z = vec![0.0; vec_size]; - for i in 0..size { - z[idx_3 + i] = x[idx_1 + i] / y[idx_2 + i]; - } - - let a = Storage::from_vec(x.clone(), DEVICE); - let b = Storage::from_vec(y.clone(), DEVICE); - let mut c = Storage::new(0.0, vec_size, DEVICE); - Storage::div(&a, &b, &mut c, idx_1, idx_2, idx_3, size); - - assert!(y.as_slice() == b.get_items(0, vec_size)); - } -} -#[test] -fn test_storage_sum() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let size = random::() % (vec_size - idx_1) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let mut y = vec![0.0; vec_size]; - for i in 0..size { - y[idx_2] += x[idx_1 + i]; - } - - let a = Storage::from_vec(x.clone(), DEVICE); - let mut b = Storage::new(0.0, vec_size, DEVICE); - Storage::sum(&a, &mut b, idx_1, idx_2, size); - - assert!((y[idx_2] - b.get_items(idx_2, 1)[0]).abs() < 1e-5 * y[idx_2].abs()); - } -} -#[test] -fn test_storage_max() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let size = random::() % (vec_size - idx_1) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let mut y = vec![0.0; vec_size]; - y[idx_2] = f32::NEG_INFINITY; - for i in 0..size { - if y[idx_2] < x[idx_1 + i] { - y[idx_2] = x[idx_1 + i]; - } - } - - let a = Storage::from_vec(x.clone(), DEVICE); - let mut b = Storage::new(0.0, vec_size, DEVICE); - Storage::max(&a, &mut b, idx_1, idx_2, size); - - assert!(y[idx_2] == b.get_items(idx_2, 1)[0]); - } -} diff --git a/src/storage_ptr/mod.rs b/src/storage_ptr/mod.rs deleted file mode 100644 index f6ce4f6..0000000 --- a/src/storage_ptr/mod.rs +++ /dev/null @@ -1,193 +0,0 @@ -use std::sync::{Arc, Mutex}; - -use crate::storage::Storage; - -#[cfg(test)] -mod tests; - -#[pyo3::pyclass] -#[derive(Clone)] -pub struct StoragePtr { - storage: Arc>, // Arc allows shared ownership, Mutex ensures safe mutable access -} - -impl StoragePtr { - fn new(storage: Storage) -> StoragePtr { - StoragePtr { - storage: Arc::new(Mutex::new(storage)), - } - } - fn get_storage(&self) -> Arc> { - Arc::clone(&self.storage) - } -} - -#[pyo3::pymethods] -impl StoragePtr { - #[staticmethod] - fn full(value: f32, size: usize, device: &str) -> StoragePtr { - StoragePtr::new(Storage::new(value, size, device)) - } - #[staticmethod] - fn from_vec(vec: Vec, device: &str) -> StoragePtr { - StoragePtr::new(Storage::from_vec(vec, device)) - } - #[staticmethod] - fn neg(a: &StoragePtr, b: &StoragePtr, idx_a: usize, idx_b: usize, size: usize) { - let binding = a.get_storage(); - let a_storage = binding.lock().unwrap(); - let binding = b.get_storage(); - let mut b_storage = binding.lock().unwrap(); - Storage::neg(&a_storage, &mut b_storage, idx_a, idx_b, size); - } - #[staticmethod] - fn sqrt(a: &StoragePtr, b: &StoragePtr, idx_a: usize, idx_b: usize, size: usize) { - let binding = a.get_storage(); - let a_storage = binding.lock().unwrap(); - let binding = b.get_storage(); - let mut b_storage = binding.lock().unwrap(); - Storage::sqrt(&a_storage, &mut b_storage, idx_a, idx_b, size); - } - #[staticmethod] - fn exp(a: &StoragePtr, b: &StoragePtr, idx_a: usize, idx_b: usize, size: usize) { - let binding = a.get_storage(); - let a_storage = binding.lock().unwrap(); - let binding = b.get_storage(); - let mut b_storage = binding.lock().unwrap(); - Storage::exp(&a_storage, &mut b_storage, idx_a, idx_b, size); - } - #[staticmethod] - fn log(a: &StoragePtr, b: &StoragePtr, idx_a: usize, idx_b: usize, size: usize) { - let binding = a.get_storage(); - let a_storage = binding.lock().unwrap(); - let binding = b.get_storage(); - let mut b_storage = binding.lock().unwrap(); - Storage::log(&a_storage, &mut b_storage, idx_a, idx_b, size); - } - #[staticmethod] - fn add( - a: &StoragePtr, - b: &StoragePtr, - c: &StoragePtr, - idx_a: usize, - idx_b: usize, - idx_c: usize, - size: usize, - ) { - let binding = a.get_storage(); - let a_storage = binding.lock().unwrap(); - let binding = b.get_storage(); - let b_storage = binding.lock().unwrap(); - let binding = c.get_storage(); - let mut c_storage = binding.lock().unwrap(); - Storage::add( - &a_storage, - &b_storage, - &mut c_storage, - idx_a, - idx_b, - idx_c, - size, - ); - } - #[staticmethod] - fn sub( - a: &StoragePtr, - b: &StoragePtr, - c: &StoragePtr, - idx_a: usize, - idx_b: usize, - idx_c: usize, - size: usize, - ) { - let binding = a.get_storage(); - let a_storage = binding.lock().unwrap(); - let binding = b.get_storage(); - let b_storage = binding.lock().unwrap(); - let binding = c.get_storage(); - let mut c_storage = binding.lock().unwrap(); - Storage::sub( - &a_storage, - &b_storage, - &mut c_storage, - idx_a, - idx_b, - idx_c, - size, - ); - } - #[staticmethod] - fn mul( - a: &StoragePtr, - b: &StoragePtr, - c: &StoragePtr, - idx_a: usize, - idx_b: usize, - idx_c: usize, - size: usize, - ) { - let binding = a.get_storage(); - let a_storage = binding.lock().unwrap(); - let binding = b.get_storage(); - let b_storage = binding.lock().unwrap(); - let binding = c.get_storage(); - let mut c_storage = binding.lock().unwrap(); - Storage::mul( - &a_storage, - &b_storage, - &mut c_storage, - idx_a, - idx_b, - idx_c, - size, - ); - } - #[staticmethod] - fn div( - a: &StoragePtr, - b: &StoragePtr, - c: &StoragePtr, - idx_a: usize, - idx_b: usize, - idx_c: usize, - size: usize, - ) { - let binding = a.get_storage(); - let a_storage = binding.lock().unwrap(); - let binding = b.get_storage(); - let b_storage = binding.lock().unwrap(); - let binding = c.get_storage(); - let mut c_storage = binding.lock().unwrap(); - Storage::div( - &a_storage, - &b_storage, - &mut c_storage, - idx_a, - idx_b, - idx_c, - size, - ); - } - #[staticmethod] - fn sum(a: &StoragePtr, b: &StoragePtr, idx_a: usize, idx_b: usize, size: usize) { - let binding = a.get_storage(); - let a_storage = binding.lock().unwrap(); - let binding = b.get_storage(); - let mut b_storage = binding.lock().unwrap(); - Storage::sum(&a_storage, &mut b_storage, idx_a, idx_b, size); - } - #[staticmethod] - fn max(a: &StoragePtr, b: &StoragePtr, idx_a: usize, idx_b: usize, size: usize) { - let binding = a.get_storage(); - let a_storage = binding.lock().unwrap(); - let binding = b.get_storage(); - let mut b_storage = binding.lock().unwrap(); - Storage::max(&a_storage, &mut b_storage, idx_a, idx_b, size); - } - - fn to_vec(&self) -> Vec { - let binding = self.get_storage(); - let self_storage = binding.lock().unwrap(); - self_storage.to_vec() - } -} diff --git a/src/storage_ptr/tests.rs b/src/storage_ptr/tests.rs deleted file mode 100644 index 9cac129..0000000 --- a/src/storage_ptr/tests.rs +++ /dev/null @@ -1,242 +0,0 @@ -use crate::storage_ptr::StoragePtr; -use rand::random; - -const DEVICE: &str = "cpu"; -const MAX_VEC_SIZE: usize = 4000; -const DEFAULT_TEST_COUNT: usize = 10; - -#[test] -fn test_storage_ptr_neg() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let mut y = vec![0.0; vec_size]; - for i in 0..size { - y[idx_2 + i] = -x[idx_1 + i]; - } - - let a = StoragePtr::from_vec(x.clone(), DEVICE); - let b = StoragePtr::full(0.0, vec_size, DEVICE); - StoragePtr::neg(&a, &b, idx_1, idx_2, size); - - assert_eq!(b.to_vec(), y); - } -} -#[test] -fn test_storage_ptr_sqrt() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let mut y = vec![0.0; vec_size]; - for i in 0..size { - y[idx_2 + i] = x[idx_1 + i].sqrt(); - } - - let a = StoragePtr::from_vec(x.clone(), DEVICE); - let b = StoragePtr::full(0.0, vec_size, DEVICE); - StoragePtr::sqrt(&a, &b, idx_1, idx_2, size); - - assert_eq!(b.to_vec(), y); - } -} -#[test] -fn test_storage_ptr_exp() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let mut y = vec![0.0; vec_size]; - for i in 0..size { - y[idx_2 + i] = x[idx_1 + i].exp(); - } - - let a = StoragePtr::from_vec(x.clone(), DEVICE); - let b = StoragePtr::full(0.0, vec_size, DEVICE); - StoragePtr::exp(&a, &b, idx_1, idx_2, size); - - assert_eq!(b.to_vec(), y); - } -} -#[test] -fn test_storage_ptr_log() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let mut y = vec![0.0; vec_size]; - for i in 0..size { - y[idx_2 + i] = x[idx_1 + i].ln(); - } - - let a = StoragePtr::from_vec(x.clone(), DEVICE); - let b = StoragePtr::full(0.0, vec_size, DEVICE); - StoragePtr::log(&a, &b, idx_1, idx_2, size); - - assert_eq!(b.to_vec(), y); - } -} -#[test] -fn test_storage_ptr_add() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let idx_3 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2).max(idx_3)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let y = (0..vec_size).map(|_| random::()).collect::>(); - let mut z = vec![0.0; vec_size]; - for i in 0..size { - z[idx_3 + i] = x[idx_1 + i] + y[idx_2 + i]; - } - - let a = StoragePtr::from_vec(x.clone(), DEVICE); - let b = StoragePtr::from_vec(y.clone(), DEVICE); - let c = StoragePtr::full(0.0, vec_size, DEVICE); - StoragePtr::add(&a, &b, &c, idx_1, idx_2, idx_3, size); - - assert_eq!(c.to_vec(), z); - } -} -#[test] -fn test_storage_ptr_sub() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let idx_3 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2).max(idx_3)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let y = (0..vec_size).map(|_| random::()).collect::>(); - let mut z = vec![0.0; vec_size]; - for i in 0..size { - z[idx_3 + i] = x[idx_1 + i] - y[idx_2 + i]; - } - - let a = StoragePtr::from_vec(x.clone(), DEVICE); - let b = StoragePtr::from_vec(y.clone(), DEVICE); - let c = StoragePtr::full(0.0, vec_size, DEVICE); - StoragePtr::sub(&a, &b, &c, idx_1, idx_2, idx_3, size); - - assert_eq!(c.to_vec(), z); - } -} -#[test] -fn test_storage_ptr_mul() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let idx_3 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2).max(idx_3)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let y = (0..vec_size).map(|_| random::()).collect::>(); - let mut z = vec![0.0; vec_size]; - for i in 0..size { - z[idx_3 + i] = x[idx_1 + i] * y[idx_2 + i]; - } - - let a = StoragePtr::from_vec(x.clone(), DEVICE); - let b = StoragePtr::from_vec(y.clone(), DEVICE); - let c = StoragePtr::full(0.0, vec_size, DEVICE); - StoragePtr::mul(&a, &b, &c, idx_1, idx_2, idx_3, size); - - assert_eq!(c.to_vec(), z); - } -} -#[test] -fn test_storage_ptr_div() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let idx_3 = random::() % vec_size; - let size = random::() % (vec_size - idx_1.max(idx_2).max(idx_3)) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let y = (0..vec_size).map(|_| random::()).collect::>(); - let mut z = vec![0.0; vec_size]; - for i in 0..size { - z[idx_3 + i] = x[idx_1 + i] / y[idx_2 + i]; - } - - let a = StoragePtr::from_vec(x.clone(), DEVICE); - let b = StoragePtr::from_vec(y.clone(), DEVICE); - let c = StoragePtr::full(0.0, vec_size, DEVICE); - StoragePtr::div(&a, &b, &c, idx_1, idx_2, idx_3, size); - - assert_eq!(c.to_vec(), z); - } -} -#[test] -fn test_storage_ptr_sum() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let size = random::() % (vec_size - idx_1) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let mut y = vec![0.0; vec_size]; - for i in 0..size { - y[idx_2] += x[idx_1 + i]; - } - - let a = StoragePtr::from_vec(x.clone(), DEVICE); - let b = StoragePtr::full(0.0, vec_size, DEVICE); - StoragePtr::sum(&a, &b, idx_1, idx_2, size); - - assert!((y[idx_2] - b.to_vec()[idx_2]).abs() < 1e-5 * y[idx_2].abs()); - } -} -#[test] -fn test_storage_ptr_max() { - for _ in 0..DEFAULT_TEST_COUNT { - let vec_size = random::() % MAX_VEC_SIZE + 1; - - let idx_1 = random::() % vec_size; - let idx_2 = random::() % vec_size; - let size = random::() % (vec_size - idx_1) + 1; - - let x = (0..vec_size).map(|_| random::()).collect::>(); - let mut y = vec![0.0; vec_size]; - y[idx_2] = f32::NEG_INFINITY; - for i in 0..size { - if y[idx_2] < x[idx_1 + i] { - y[idx_2] = x[idx_1 + i]; - } - } - - let a = StoragePtr::from_vec(x.clone(), DEVICE); - let b = StoragePtr::full(0.0, vec_size, DEVICE); - StoragePtr::max(&a, &b, idx_1, idx_2, size); - - assert_eq!(b.to_vec(), y); - } -} diff --git a/src/tests/backend.rs b/src/tests/backend.rs new file mode 100644 index 0000000..a115ff8 --- /dev/null +++ b/src/tests/backend.rs @@ -0,0 +1,85 @@ +use std::sync::Arc; + +use crate::backend::{Backend, BinaryOp, CpuBackend, UnaryOp}; +use crate::core::{storage::StorageInner, view::View}; +use crate::device::Device; + +fn vec_view(v: Vec) -> View { + let inner = Arc::new(StorageInner::from_vec(v, Device::Cpu)); + View::from_inner_1d(inner) +} + +#[test] +fn unary_neg_matches_scalar() { + let a = vec_view(vec![1.0, -2.0, 3.0, -4.5, 0.25]); + let be = CpuBackend; + let out = be.unary(UnaryOp::Neg, &a).unwrap(); + let actual = out.inner.as_slice(out.offset, out.numel()); + let expect: Vec = a + .inner + .as_slice(a.offset, a.numel()) + .iter() + .map(|x| -*x) + .collect(); + assert_eq!(actual, expect.as_slice()); +} + +#[test] +fn unary_sqrt_matches_scalar() { + let a = vec_view(vec![0.0, 0.25, 1.0, 4.0, 9.0, 16.0, 2.25]); + let be = CpuBackend; + let out = be.unary(UnaryOp::Sqrt, &a).unwrap(); + let actual = out.inner.as_slice(out.offset, out.numel()); + let expect: Vec = a + .inner + .as_slice(a.offset, a.numel()) + .iter() + .map(|x| x.sqrt()) + .collect(); + for (got, exp) in actual.iter().zip(expect.iter()) { + assert!((got - exp).abs() < 1e-6); + } +} + +#[test] +fn binary_add_matches_scalar() { + let a = vec_view((0..130).map(|i| i as f32 * 0.5).collect()); // span SIMD + remainder + let b = vec_view((0..130).map(|i| i as f32 * -0.25).collect()); + let be = CpuBackend; + let out = be.binary(BinaryOp::Add, &a, &b).unwrap(); + let actual = out.inner.as_slice(out.offset, out.numel()); + let a_s = a.inner.as_slice(a.offset, a.numel()); + let b_s = b.inner.as_slice(b.offset, b.numel()); + for i in 0..a.numel() { + assert!((actual[i] - (a_s[i] + b_s[i])).abs() < 1e-6); + } +} + +#[test] +fn binary_shape_mismatch_error() { + let a = vec_view(vec![1.0, 2.0, 3.0]); + let b = vec_view(vec![4.0, 5.0]); + let be = CpuBackend; + let err = be.binary(BinaryOp::Add, &a, &b).unwrap_err(); + assert!(matches!(err, crate::backend::BackendError::ShapeMismatch)); +} + +#[test] +fn unary_not_contiguous_error() { + let a = vec_view((0..12).map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); + let a_nc = a.permute(&[1, 0]); + assert!(!a_nc.is_contiguous()); + let be = CpuBackend; + let err = be.unary(UnaryOp::Neg, &a_nc).unwrap_err(); + assert!(matches!(err, crate::backend::BackendError::NotContiguous)); +} + +#[test] +fn binary_not_contiguous_error() { + let a = vec_view((0..12).map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); + let b = vec_view((0..12).rev().map(|i| i as f32).collect()).reshape_contiguous(&[3, 4]); + let a_nc = a.permute(&[1, 0]); + let be = CpuBackend; + let err = be.binary(BinaryOp::Add, &a_nc, &b).unwrap_err(); + assert!(matches!(err, crate::backend::BackendError::NotContiguous)); +} diff --git a/src/tests/device.rs b/src/tests/device.rs new file mode 100644 index 0000000..9708557 --- /dev/null +++ b/src/tests/device.rs @@ -0,0 +1,16 @@ +use crate::device::Device; + +#[test] +fn device_from_str_variants() { + assert!(matches!(Device::from_str("cpu"), Device::Cpu)); + assert!(matches!(Device::from_str("metal"), Device::Metal)); + assert!(matches!(Device::from_str("cuda"), Device::Cuda)); +} + +#[test] +fn device_from_str_invalid_panics() { + let res = std::panic::catch_unwind(|| { + let _ = Device::from_str("weird"); + }); + assert!(res.is_err()); +} diff --git a/src/tests/mod.rs b/src/tests/mod.rs new file mode 100644 index 0000000..d970109 --- /dev/null +++ b/src/tests/mod.rs @@ -0,0 +1,3 @@ +mod backend; +mod device; +mod view; diff --git a/src/tests/view.rs b/src/tests/view.rs new file mode 100644 index 0000000..a1fe64c --- /dev/null +++ b/src/tests/view.rs @@ -0,0 +1,77 @@ +use std::sync::Arc; + +use crate::core::{storage::StorageInner, view::View}; +use crate::device::Device; + +#[test] +fn view_1d_contiguous_and_numel() { + let inner = Arc::new(StorageInner::new_full(0.0, 10, Device::Cpu)); + let v = View::from_inner_1d(inner); + assert!(v.is_contiguous()); + assert_eq!(v.numel(), 10); + assert_eq!(v.shape, vec![10]); + assert_eq!(v.strides, vec![1]); +} + +#[test] +fn view_slice_1d() { + let inner = Arc::new(StorageInner::new_full(0.0, 10, Device::Cpu)); + let v = View::from_inner_1d(inner); + let s = v.slice_1d(3, 4); + assert!(s.is_contiguous()); + assert_eq!(s.offset, 3); + assert_eq!(s.shape, vec![4]); + assert_eq!(s.strides, vec![1]); +} + +#[test] +fn view_reshape_contiguous() { + let inner = Arc::new(StorageInner::new_full(0.0, 12, Device::Cpu)); + let v = View::from_inner_1d(inner); + let r = v.reshape_contiguous(&[3, 4]); + assert!(r.is_contiguous()); + assert_eq!(r.shape, vec![3, 4]); + assert_eq!(r.strides, vec![4, 1]); + assert_eq!(r.offset, 0); +} + +#[test] +fn view_permute_non_contiguous() { + let inner = Arc::new(StorageInner::new_full(0.0, 12, Device::Cpu)); + let v = View::from_inner_1d(inner).reshape_contiguous(&[3, 4]); + let p = v.permute(&[1, 0]); + assert_eq!(p.shape, vec![4, 3]); + assert_eq!(p.strides, vec![1, 4]); + assert!(!p.is_contiguous()); +} + +#[test] +fn view_expand_broadcast() { + let inner = Arc::new(StorageInner::new_full(0.0, 3, Device::Cpu)); + let v = View::from_inner_1d(inner).reshape_contiguous(&[1, 3]); + let e = v.expand(&[2, 3]); + assert_eq!(e.shape, vec![2, 3]); + assert_eq!(e.strides, vec![0, 1]); // broadcasted axis gets stride 0 + assert!(!e.is_contiguous()); +} + +#[test] +fn view_expand_incompatible_panics() { + let inner = Arc::new(StorageInner::new_full(0.0, 3, Device::Cpu)); + let v = View::from_inner_1d(inner); + let res = std::panic::catch_unwind(|| { + let _ = v.expand(&[4]); + }); + assert!(res.is_err()); +} + +#[test] +fn view_reshape_requires_contiguous_panics() { + let inner = Arc::new(StorageInner::new_full(0.0, 12, Device::Cpu)); + let v = View::from_inner_1d(inner).reshape_contiguous(&[3, 4]); + let p = v.permute(&[1, 0]); + let res = std::panic::catch_unwind(|| { + let _ = p.reshape_contiguous(&[2, 6]); + }); + assert!(res.is_err()); +} diff --git a/tests/test_storage_ptr.py b/tests/test_storage_ptr.py deleted file mode 100644 index 2d5c2a9..0000000 --- a/tests/test_storage_ptr.py +++ /dev/null @@ -1,211 +0,0 @@ -import unittest -import numpy as np -from cranberry import StoragePtr - -np.random.seed(1337) - -MAX_VEC_SIZE = 4000 -DEFAULT_TEST_COUNT = 10 -DEVICE = "cpu" - -rtol, atol = 1e-4, 1e-4 - - -class TestStoragePtr(unittest.TestCase): - def test_storage_ptr_neg(self): - for _ in range(DEFAULT_TEST_COUNT): - vec_size = np.random.randint(1, MAX_VEC_SIZE + 1) - - idx_1 = np.random.randint(0, vec_size) - idx_2 = np.random.randint(0, vec_size) - size = np.random.randint(1, min(vec_size - idx_1, vec_size - idx_2) + 1) - - x = np.random.rand(vec_size).astype(np.float32) - y = np.zeros(vec_size, dtype=np.float32) - y[idx_2 : idx_2 + size] = -x[idx_1 : idx_1 + size] - - a = StoragePtr.from_vec(x, DEVICE) - b = StoragePtr.full(0.0, vec_size, DEVICE) - StoragePtr.neg(a, b, idx_1, idx_2, size) - - np.testing.assert_allclose(y, b.to_vec(), rtol, atol) - - def test_storage_ptr_sqrt(self): - for _ in range(DEFAULT_TEST_COUNT): - vec_size = np.random.randint(1, MAX_VEC_SIZE + 1) - - idx_1 = np.random.randint(0, vec_size) - idx_2 = np.random.randint(0, vec_size) - size = np.random.randint(1, min(vec_size - idx_1, vec_size - idx_2) + 1) - - x = np.random.rand(vec_size).astype(np.float32) - y = np.zeros(vec_size, dtype=np.float32) - y[idx_2 : idx_2 + size] = np.sqrt(x[idx_1 : idx_1 + size]) - - a = StoragePtr.from_vec(x, DEVICE) - b = StoragePtr.full(0.0, vec_size, DEVICE) - StoragePtr.sqrt(a, b, idx_1, idx_2, size) - - np.testing.assert_allclose(y, b.to_vec(), rtol, atol) - - def test_storage_ptr_exp(self): - for _ in range(DEFAULT_TEST_COUNT): - vec_size = np.random.randint(1, MAX_VEC_SIZE + 1) - - idx_1 = np.random.randint(0, vec_size) - idx_2 = np.random.randint(0, vec_size) - size = np.random.randint(1, min(vec_size - idx_1, vec_size - idx_2) + 1) - - x = np.random.rand(vec_size).astype(np.float32) - y = np.zeros(vec_size, dtype=np.float32) - y[idx_2 : idx_2 + size] = np.exp(x[idx_1 : idx_1 + size]) - - a = StoragePtr.from_vec(x, DEVICE) - b = StoragePtr.full(0.0, vec_size, DEVICE) - StoragePtr.exp(a, b, idx_1, idx_2, size) - - np.testing.assert_allclose(y, b.to_vec(), rtol, atol) - - def test_storage_ptr_log(self): - for _ in range(DEFAULT_TEST_COUNT): - vec_size = np.random.randint(1, MAX_VEC_SIZE + 1) - - idx_1 = np.random.randint(0, vec_size) - idx_2 = np.random.randint(0, vec_size) - size = np.random.randint(1, min(vec_size - idx_1, vec_size - idx_2) + 1) - - x = np.random.rand(vec_size).astype(np.float32) - y = np.zeros(vec_size, dtype=np.float32) - y[idx_2 : idx_2 + size] = np.log(x[idx_1 : idx_1 + size]) - - a = StoragePtr.from_vec(x, DEVICE) - b = StoragePtr.full(0.0, vec_size, DEVICE) - StoragePtr.log(a, b, idx_1, idx_2, size) - - np.testing.assert_allclose(y, b.to_vec(), rtol, atol) - - def test_storage_ptr_add(self): - for _ in range(DEFAULT_TEST_COUNT): - vec_size = np.random.randint(1, MAX_VEC_SIZE + 1) - - idx_1 = np.random.randint(0, vec_size) - idx_2 = np.random.randint(0, vec_size) - idx_3 = np.random.randint(0, vec_size) - - size = np.random.randint(1, min(vec_size - idx_1, vec_size - idx_2, vec_size - idx_3) + 1) - - x = np.random.rand(vec_size).astype(np.float32) - y = np.random.rand(vec_size).astype(np.float32) - z = np.zeros(vec_size, dtype=np.float32) - z[idx_3 : idx_3 + size] = x[idx_1 : idx_1 + size] + y[idx_2 : idx_2 + size] - - a = StoragePtr.from_vec(x, DEVICE) - b = StoragePtr.from_vec(y, DEVICE) - c = StoragePtr.full(0.0, vec_size, DEVICE) - StoragePtr.add(a, b, c, idx_1, idx_2, idx_3, size) - - np.testing.assert_allclose(z, c.to_vec(), rtol, atol) - - def test_storage_ptr_sub(self): - for _ in range(DEFAULT_TEST_COUNT): - vec_size = np.random.randint(1, MAX_VEC_SIZE + 1) - - idx_1 = np.random.randint(0, vec_size) - idx_2 = np.random.randint(0, vec_size) - idx_3 = np.random.randint(0, vec_size) - - size = np.random.randint(1, min(vec_size - idx_1, vec_size - idx_2, vec_size - idx_3) + 1) - - x = np.random.rand(vec_size).astype(np.float32) - y = np.random.rand(vec_size).astype(np.float32) - z = np.zeros(vec_size, dtype=np.float32) - z[idx_3 : idx_3 + size] = x[idx_1 : idx_1 + size] - y[idx_2 : idx_2 + size] - - a = StoragePtr.from_vec(x, DEVICE) - b = StoragePtr.from_vec(y, DEVICE) - c = StoragePtr.full(0.0, vec_size, DEVICE) - StoragePtr.sub(a, b, c, idx_1, idx_2, idx_3, size) - - np.testing.assert_allclose(z, c.to_vec(), rtol, atol) - - def test_storage_ptr_mul(self): - for _ in range(DEFAULT_TEST_COUNT): - vec_size = np.random.randint(1, MAX_VEC_SIZE + 1) - - idx_1 = np.random.randint(0, vec_size) - idx_2 = np.random.randint(0, vec_size) - idx_3 = np.random.randint(0, vec_size) - - size = np.random.randint(1, min(vec_size - idx_1, vec_size - idx_2, vec_size - idx_3) + 1) - - x = np.random.rand(vec_size).astype(np.float32) - y = np.random.rand(vec_size).astype(np.float32) - z = np.zeros(vec_size, dtype=np.float32) - z[idx_3 : idx_3 + size] = x[idx_1 : idx_1 + size] * y[idx_2 : idx_2 + size] - - a = StoragePtr.from_vec(x, DEVICE) - b = StoragePtr.from_vec(y, DEVICE) - c = StoragePtr.full(0.0, vec_size, DEVICE) - StoragePtr.mul(a, b, c, idx_1, idx_2, idx_3, size) - - def test_storage_ptr_div(self): - for _ in range(DEFAULT_TEST_COUNT): - vec_size = np.random.randint(1, MAX_VEC_SIZE + 1) - - idx_1 = np.random.randint(0, vec_size) - idx_2 = np.random.randint(0, vec_size) - idx_3 = np.random.randint(0, vec_size) - - size = np.random.randint(1, min(vec_size - idx_1, vec_size - idx_2, vec_size - idx_3) + 1) - - x = np.random.rand(vec_size).astype(np.float32) - y = np.random.rand(vec_size).astype(np.float32) - z = np.zeros(vec_size, dtype=np.float32) - z[idx_3 : idx_3 + size] = x[idx_1 : idx_1 + size] / y[idx_2 : idx_2 + size] - - a = StoragePtr.from_vec(x, DEVICE) - b = StoragePtr.from_vec(y, DEVICE) - c = StoragePtr.full(0.0, vec_size, DEVICE) - StoragePtr.div(a, b, c, idx_1, idx_2, idx_3, size) - - np.testing.assert_allclose(z, c.to_vec(), rtol, atol) - - def test_storage_ptr_sum(self): - for _ in range(DEFAULT_TEST_COUNT): - vec_size = np.random.randint(1, MAX_VEC_SIZE + 1) - - idx_1 = np.random.randint(0, vec_size) - idx_2 = np.random.randint(0, vec_size) - size = np.random.randint(1, vec_size - idx_1 + 1) - - x = np.random.rand(vec_size).astype(np.float32) - y = np.zeros(vec_size, dtype=np.float32) - y[idx_2] = np.sum(x[idx_1 : idx_1 + size]) - - a = StoragePtr.from_vec(x, DEVICE) - b = StoragePtr.full(0.0, vec_size, DEVICE) - StoragePtr.sum(a, b, idx_1, idx_2, size) - - np.testing.assert_allclose(y, b.to_vec(), rtol, atol) - - def test_storage_ptr_max(self): - for _ in range(DEFAULT_TEST_COUNT): - vec_size = np.random.randint(1, MAX_VEC_SIZE + 1) - - idx_1 = np.random.randint(0, vec_size) - idx_2 = np.random.randint(0, vec_size) - size = np.random.randint(1, vec_size - idx_1 + 1) - - x = np.random.rand(vec_size).astype(np.float32) - y = np.zeros(vec_size, dtype=np.float32) - y[idx_2] = np.max(x[idx_1 : idx_1 + size]) - - a = StoragePtr.from_vec(x, DEVICE) - b = StoragePtr.full(0.0, vec_size, DEVICE) - StoragePtr.max(a, b, idx_1, idx_2, size) - - np.testing.assert_allclose(y, b.to_vec(), rtol, atol) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_storageview.py b/tests/test_storageview.py new file mode 100644 index 0000000..46541c4 --- /dev/null +++ b/tests/test_storageview.py @@ -0,0 +1,353 @@ +import numpy as np +import unittest + +from cranberry import StorageView + + +np.random.seed(1337) + +N, M = 54, 29 # keep in sync with test_tensor.py +A_np = np.random.randn(N * M).astype(np.float32) +B_np = np.random.randn(N * M).astype(np.float32) + +rtol, atol = 1e-5, 1e-5 + + +class TestStorageView(unittest.TestCase): + def test_from_vec_len_to_vec(self): + v = StorageView.from_vec([1.0, 2.0, 3.5], "cpu") + self.assertEqual(v.len(), 3) + self.assertEqual(v.to_vec(), [1.0, 2.0, 3.5]) + + def test_full(self): + v = StorageView.full(2.5, 4, "cpu") + self.assertEqual(v.len(), 4) + self.assertEqual(v.to_vec(), [2.5, 2.5, 2.5, 2.5]) + + # unary ops: neg, sqrt, exp, log + + def test_neg_1d(self): + def test_cranberry(): + a = StorageView.from_vec(A_np.tolist(), "cpu") + out = a.neg() + return np.array(out.to_vec(), dtype=np.float32) + + def test_numpy(): + return -A_np + + np.testing.assert_allclose(test_cranberry(), test_numpy(), rtol, atol) + + def test_sqrt_1d(self): + # ensure positivity + data = (np.abs(A_np) + 1e-3).astype(np.float32) + + def test_cranberry(): + a = StorageView.from_vec(data.tolist(), "cpu") + out = a.sqrt() + return np.array(out.to_vec(), dtype=np.float32) + + def test_numpy(): + return np.sqrt(data) + + np.testing.assert_allclose(test_cranberry(), test_numpy(), rtol, atol) + + def test_exp_1d(self): + # clamp to avoid inf in exp + data = np.clip(A_np, -10, 10) + + def test_cranberry(): + a = StorageView.from_vec(data.tolist(), "cpu") + out = a.exp() + return np.array(out.to_vec(), dtype=np.float32) + + def test_numpy(): + return np.exp(data) + + np.testing.assert_allclose(test_cranberry(), test_numpy(), rtol, atol) + + def test_log_1d(self): + data = (np.abs(A_np) + 1e-3).astype(np.float32) + + def test_cranberry(): + a = StorageView.from_vec(data.tolist(), "cpu") + out = a.log() + return np.array(out.to_vec(), dtype=np.float32) + + def test_numpy(): + return np.log(data) + + np.testing.assert_allclose(test_cranberry(), test_numpy(), rtol, atol) + + # binary ops: add, sub, mul, div (1D and 2D contiguous) + + def test_add_1d(self): + def test_cranberry(): + a = StorageView.from_vec(A_np.tolist(), "cpu") + b = StorageView.from_vec(B_np.tolist(), "cpu") + out = a.add(b) + return np.array(out.to_vec(), dtype=np.float32) + + def test_numpy(): + return A_np + B_np + + np.testing.assert_allclose(test_cranberry(), test_numpy(), rtol, atol) + + def test_sub_1d(self): + def test_cranberry(): + a = StorageView.from_vec(A_np.tolist(), "cpu") + b = StorageView.from_vec(B_np.tolist(), "cpu") + out = a.sub(b) + return np.array(out.to_vec(), dtype=np.float32) + + def test_numpy(): + return A_np - B_np + + np.testing.assert_allclose(test_cranberry(), test_numpy(), rtol, atol) + + def test_mul_1d(self): + def test_cranberry(): + a = StorageView.from_vec(A_np.tolist(), "cpu") + b = StorageView.from_vec(B_np.tolist(), "cpu") + out = a.mul(b) + return np.array(out.to_vec(), dtype=np.float32) + + def test_numpy(): + return A_np * B_np + + np.testing.assert_allclose(test_cranberry(), test_numpy(), rtol, atol) + + def test_div_1d(self): + bnp = B_np.copy() + bnp[bnp == 0] = 1.0 + + def test_cranberry(): + a = StorageView.from_vec(A_np.tolist(), "cpu") + b = StorageView.from_vec(bnp.tolist(), "cpu") + out = a.div(b) + return np.array(out.to_vec(), dtype=np.float32) + + def test_numpy(): + return A_np / bnp + + np.testing.assert_allclose(test_cranberry(), test_numpy(), rtol, atol) + + def test_unary_2d_then_flatten_compare(self): + data = np.arange(N * M, dtype=np.float32) + + def test_cranberry(): + a = StorageView.from_vec(data.tolist(), "cpu").reshape([N, M]) + out = a.neg().exp().log() # identity-ish + return np.array(out.to_vec(), dtype=np.float32) + + def test_numpy(): + with np.errstate(divide="ignore", invalid="ignore"): + out = np.log(np.exp(-data.reshape(N, M))).reshape(-1) + return out + + np.testing.assert_allclose(test_cranberry(), test_numpy(), rtol, atol) + + def test_unary_random_shapes(self): + rng = np.random.default_rng(1337) + shapes = [(N * M,), (10, 10), (3, 4, 5)] + + for shape in shapes: + data = rng.standard_normal(np.prod(shape)).astype(np.float32) + # prepare variants for stability per op + pos = (np.abs(data) + 1e-3).astype(np.float32) # for sqrt/log + clip = np.clip(data, -10, 10) # for exp + + # neg + a = StorageView.from_vec(data.tolist(), "cpu") + if len(shape) > 1: + a = a.reshape(list(shape)) + out = np.array(a.neg().to_vec(), dtype=np.float32) + np.testing.assert_allclose(out, -data.reshape(-1), rtol, atol) + + # sqrt + a = StorageView.from_vec(pos.tolist(), "cpu") + if len(shape) > 1: + a = a.reshape(list(shape)) + out = np.array(a.sqrt().to_vec(), dtype=np.float32) + np.testing.assert_allclose(out, np.sqrt(pos).reshape(-1), rtol, atol) + + # exp + a = StorageView.from_vec(clip.tolist(), "cpu") + if len(shape) > 1: + a = a.reshape(list(shape)) + out = np.array(a.exp().to_vec(), dtype=np.float32) + np.testing.assert_allclose(out, np.exp(clip).reshape(-1), rtol, atol) + + # log + a = StorageView.from_vec(pos.tolist(), "cpu") + if len(shape) > 1: + a = a.reshape(list(shape)) + out = np.array(a.log().to_vec(), dtype=np.float32) + np.testing.assert_allclose(out, np.log(pos).reshape(-1), rtol, atol) + + def test_binary_random_shapes(self): + rng = np.random.default_rng(2024) + shapes = [(N * M,), (17, 11), (2, 3, 5)] + + for shape in shapes: + a_np = rng.standard_normal(np.prod(shape)).astype(np.float32) + b_np = rng.standard_normal(np.prod(shape)).astype(np.float32) + b_np_div = b_np.copy() + b_np_div[np.isclose(b_np_div, 0.0)] = 1.0 + + def make_sv(arr): + sv = StorageView.from_vec(arr.reshape(-1).tolist(), "cpu") + return sv.reshape(list(shape)) if len(shape) > 1 else sv + + # add + a = make_sv(a_np) + b = make_sv(b_np) + out = np.array(a.add(b).to_vec(), dtype=np.float32) + np.testing.assert_allclose(out, (a_np + b_np).reshape(-1), rtol, atol) + + # sub + a = make_sv(a_np) + b = make_sv(b_np) + out = np.array(a.sub(b).to_vec(), dtype=np.float32) + np.testing.assert_allclose(out, (a_np - b_np).reshape(-1), rtol, atol) + + # mul + a = make_sv(a_np) + b = make_sv(b_np) + out = np.array(a.mul(b).to_vec(), dtype=np.float32) + np.testing.assert_allclose(out, (a_np * b_np).reshape(-1), rtol, atol) + + # div + a = make_sv(a_np) + b = make_sv(b_np_div) + out = np.array(a.div(b).to_vec(), dtype=np.float32) + np.testing.assert_allclose(out, (a_np / b_np_div).reshape(-1), rtol, atol) + + def test_slice_then_unary(self): + base = np.linspace(-3, 3, 101, dtype=np.float32) + v = StorageView.from_vec(base.tolist(), "cpu") + s = v.slice(5, 77) + out = np.array(s.exp().to_vec(), dtype=np.float32) + np.testing.assert_allclose(out, np.exp(base[5:82]), rtol, atol) + + def test_slice_then_binary(self): + a_base = np.linspace(-2, 2, 111, dtype=np.float32) + b_base = np.linspace(3, -3, 111, dtype=np.float32) + a = StorageView.from_vec(a_base.tolist(), "cpu").slice(7, 55) + b = StorageView.from_vec(b_base.tolist(), "cpu").slice(9, 55) + out = np.array(a.mul(b).to_vec(), dtype=np.float32) + np.testing.assert_allclose(out, (a_base[7:62] * b_base[9:64]), rtol, atol) + + def test_numpy_input(self): + arr = np.array([1.0, 2.0, 3.0], dtype=np.float32) + v = StorageView.from_vec(arr, "cpu") + self.assertEqual(v.to_vec(), [1.0, 2.0, 3.0]) + + def test_unary_simd_remainder(self): + # length intentionally not divisible by SIMD width (64) + data = np.arange(65, dtype=np.float32) + v = StorageView.from_vec(data.tolist(), "cpu") + out = np.array(v.neg().to_vec(), dtype=np.float32) + np.testing.assert_allclose(out, -data, rtol, atol) + + def test_binary_simd_remainder(self): + a = np.arange(129, dtype=np.float32) + b = np.arange(129, dtype=np.float32) * -0.25 + va = StorageView.from_vec(a.tolist(), "cpu") + vb = StorageView.from_vec(b.tolist(), "cpu") + out = np.array(va.add(vb).to_vec(), dtype=np.float32) + np.testing.assert_allclose(out, a + b, rtol, atol) + + def test_binary_2d(self): + a_np = A_np.reshape(N, M) + b_np = B_np.reshape(N, M) + + def test_cranberry(): + a = StorageView.from_vec(a_np.reshape(-1).tolist(), "cpu").reshape([N, M]) + b = StorageView.from_vec(b_np.reshape(-1).tolist(), "cpu").reshape([N, M]) + out = a.mul(b).add(a).sub(b).div(a) + return np.array(out.to_vec(), dtype=np.float32) + + def test_numpy(): + out = (a_np * b_np + a_np - b_np) / a_np + return out.reshape(-1) + + np.testing.assert_allclose(test_cranberry(), test_numpy(), rtol, atol) + + # movement semantics and error paths + + def test_slice_1d(self): + base = np.arange(100, dtype=np.float32) + off, size = 7, 23 + + def test_cranberry(): + v = StorageView.from_vec(base.tolist(), "cpu") + s = v.slice(off, size) + return np.array(s.to_vec(), dtype=np.float32) + + def test_numpy(): + return base[off : off + size] + + np.testing.assert_allclose(test_cranberry(), test_numpy(), rtol, atol) + + def test_reshape_contiguous_to_vec(self): + v = StorageView.from_vec([1.0, 2.0, 3.0, 4.0], "cpu") + r = v.reshape([2, 2]) + self.assertEqual(r.to_vec(), [1.0, 2.0, 3.0, 4.0]) + + def test_expand_non_contiguous_to_vec_raises(self): + v = StorageView.from_vec([1.0, 2.0, 3.0], "cpu").reshape([1, 3]) + e = v.expand([2, 3]) + with self.assertRaises(ValueError): + _ = e.to_vec() + + def test_permute_non_contiguous_to_vec_raises(self): + v = StorageView.from_vec([1.0, 2.0, 3.0, 4.0], "cpu").reshape([2, 2]) + p = v.permute([1, 0]) + with self.assertRaises(ValueError): + _ = p.to_vec() + + def test_unary_on_non_contiguous_raises(self): + v = StorageView.from_vec([1.0, 4.0, 9.0, 16.0], "cpu").reshape([2, 2]).permute([1, 0]) + with self.assertRaises(RuntimeError): + _ = v.neg() + with self.assertRaises(RuntimeError): + _ = v.sqrt() + with self.assertRaises(RuntimeError): + _ = v.exp() + with self.assertRaises(RuntimeError): + _ = v.log() + + def test_binary_on_non_contiguous_raises(self): + a = StorageView.from_vec([1.0, 2.0, 3.0, 4.0], "cpu").reshape([2, 2]).permute([1, 0]) + b = StorageView.from_vec([1.0, 2.0, 3.0, 4.0], "cpu").reshape([2, 2]) + with self.assertRaises(RuntimeError): + _ = a.add(b) + with self.assertRaises(RuntimeError): + _ = a.sub(b) + with self.assertRaises(RuntimeError): + _ = a.mul(b) + with self.assertRaises(RuntimeError): + _ = a.div(b) + + def test_binary_shape_mismatch_raises(self): + a = StorageView.from_vec([1.0, 2.0, 3.0], "cpu") + b = StorageView.from_vec([4.0, 5.0], "cpu") + with self.assertRaises(RuntimeError): + _ = a.add(b) + + def test_add_device_mismatch_raises(self): + a = StorageView.from_vec([1.0, 2.0], "cpu") + b = StorageView.from_vec([1.0, 2.0], "metal") + with self.assertRaises(ValueError): + _ = a.add(b) + + def test_ops_not_implemented_other_device(self): + v = StorageView.from_vec([1.0, 2.0], "metal") + with self.assertRaises(NotImplementedError): + _ = v.neg() + with self.assertRaises(NotImplementedError): + _ = v.exp() + + +if __name__ == "__main__": + unittest.main()