4 changes: 2 additions & 2 deletions .github/workflows/main.yaml
@@ -36,7 +36,7 @@ jobs:
         run: cargo fmt --all -- --check
 
       - name: Run clippy
-        run: cargo clippy --workspace --all-features
+        run: cargo clippy
 
   check_style_python:
     name: Check file formatting and style (Python)
@@ -74,7 +74,7 @@ jobs:
       - uses: Swatinem/rust-cache@v2
 
       - name: Run tests
-        run: cargo test --workspace --all-features
+        run: cargo test
 
   test_python:
     name: Tests (Python)
92 changes: 92 additions & 0 deletions Cargo.lock


3 changes: 3 additions & 0 deletions Cargo.toml
@@ -12,10 +12,13 @@ pyo3 = { version = "0.24.1", features = ["extension-module"] }
 rand = { version = "0.8.5" }
 rand_distr = "0.4"
 thiserror = "1"
+cudarc = { version = "0.17", optional = true, features = ["driver", "nvrtc", "cuda-version-from-build-system"] }
+once_cell = "1.19"
 
 [features]
 default = ["python"]
 python = []
+cuda = ["cudarc"]
 abi3 = ["pyo3/abi3-py37", "generate-import-lib"]
 generate-import-lib = ["pyo3/generate-import-lib"]

12 changes: 10 additions & 2 deletions README.md
@@ -15,13 +15,14 @@ A small deep learning framework in Rust and Python
 
 ## Overview
 
-Cranberry is an educational project exploring how a tensor library, automatic differentiation, and a Rust-backed storage layer fit together. The Python front-end intentionally stays simple while the Rust extension supplies fast contiguous kernels and view manipulation utilities. Everything currently targets CPU and 32-bit floating point tensors.
+Cranberry is an educational project exploring how a tensor library, automatic differentiation, and a Rust-backed storage layer fit together. The Python front-end intentionally stays simple while the Rust extension supplies fast contiguous kernels and view manipulation utilities. Everything targets 32-bit floating point tensors and now offers optional CUDA acceleration for pointwise kernels alongside the CPU path.
 
 ## Highlights
 
 - Python-first `Tensor` API backed by the `StorageView` PyO3 module.
 - Reverse-mode autograd with topological traversal, supporting gradient tracking through broadcasting and reshape/expand/permute transforms.
 - Contiguous CPU kernels for unary/binary ops plus sum/max reductions, with broadcasting handled in Python.
+- Optional CUDA backend for contiguous unary and binary operations when an NVIDIA GPU and toolkit are available.
 - Basic neural-network building blocks (`nn.Linear`, `nn.ReLU`, `nn.Sequential`) and stochastic gradient descent in `optim.SGD`.
 - Visualization helpers for autograd graphs (`cranberry.features.visualize`) and an MNIST downloader with caching (`cranberry.features.datasets`).
@@ -56,11 +57,12 @@ Cranberry is an educational project exploring how a tensor library, automatic di
 **Rust Extension**
 - `StorageView` exposes contiguous tensor storage, reshaping, expanding, permuting, and random fills.
 - CPU backend implements SIMD-accelerated unary/binary kernels and reduction routines.
+- CUDA backend (via `cudarc`) mirrors the contiguous unary/binary kernels when a CUDA device is detected.
 - Views currently support up to rank-4 tensors; non-contiguous reshape/permute paths are under construction.
 
 ## Limitations & Work in Progress
 
-- CPU-only; GPU/Metal backends are stubbed out.
+- CUDA backend currently covers only contiguous unary/binary kernels; Metal remains stubbed out.
 - Autograd requires scalar losses and does not yet handle slicing/indexing/in-place mutations.
 - Views must be contiguous for most kernels; slicing and advanced indexing are not implemented.
 - Only `float32` tensors are supported; dtype promotion and mixed precision are future work.
@@ -98,6 +100,12 @@ Optional extras:
 - `pip install -e .[datasets]` for download progress via `tqdm`.
 - `pip install -e .[all]` to include every extra.
 
+## CUDA backend
+
+- Requires an NVIDIA driver and CUDA toolkit (NVRTC must be discoverable via `CUDA_HOME`, `CUDA_PATH`, or the default `/usr/local/cuda`).
+- No separate `nvcc` build step is needed; the crate compiles its kernels at runtime using NVRTC.
+- At runtime, pass `device="cuda"` when creating tensors/storage; contiguous unary and binary ops execute on the GPU, and a runtime error is raised if no CUDA device is present.
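The last bullet describes a fail-fast contract: requesting `device="cuda"` on a machine without a usable CUDA device raises an error rather than silently running on CPU. A minimal sketch of that contract in plain Python (the function name and error message here are illustrative, not part of cranberry's actual API):

```python
def resolve_device(requested: str, cuda_available: bool) -> str:
    """Return the device to use; fail fast when CUDA is requested but absent."""
    if requested == "cuda" and not cuda_available:
        # Mirrors the documented behavior: a runtime error, not a silent CPU fallback.
        raise RuntimeError("device='cuda' requested but no CUDA device is present")
    return requested
```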
+
 ## Quickstart
 
 ```python
99 changes: 79 additions & 20 deletions src/backend/kernels_simd.rs → src/backend/cpu.rs
@@ -1,14 +1,70 @@
 use std::simd::num::SimdFloat;
 use std::simd::{f32x64, StdFloat};
 
+use super::{
+    require_contiguous, require_same_numel, view_from_storage, Backend, BackendResult, BinaryOp,
+    UnaryOp,
+};
+use crate::core::{storage::StorageInner, view::View};
+use crate::device::Device;
+
 const CHUNK_SIZE: usize = 64;
 
-pub mod unary_ops {
+pub struct CpuBackend;
+
+impl CpuBackend {
+    fn apply_unary(op: UnaryOp, input: &[f32], output: &mut [f32]) {
+        unary::apply(op, input, output);
+    }
+
+    fn apply_binary(op: BinaryOp, lhs: &[f32], rhs: &[f32], output: &mut [f32]) {
+        binary::apply(op, lhs, rhs, output);
+    }
+}
+
+impl Backend for CpuBackend {
+    fn unary(&self, op: UnaryOp, a: &View) -> BackendResult<View> {
+        require_contiguous(a)?;
+        let input = a.inner.as_slice(a.offset, a.numel());
+        let mut storage = StorageInner::new_full(0.0, a.numel(), Device::Cpu);
+        {
+            let output = storage.as_mut_slice(0, a.numel());
+            Self::apply_unary(op, input, output);
+        }
+        Ok(view_from_storage(storage, &a.shape))
+    }
+
+    fn binary(&self, op: BinaryOp, a: &View, b: &View) -> BackendResult<View> {
+        require_contiguous(a)?;
+        require_contiguous(b)?;
+        require_same_numel(a, b)?;
+        let lhs = a.inner.as_slice(a.offset, a.numel());
+        let rhs = b.inner.as_slice(b.offset, b.numel());
+        let mut storage = StorageInner::new_full(0.0, a.numel(), Device::Cpu);
+        {
+            let output = storage.as_mut_slice(0, a.numel());
+            Self::apply_binary(op, lhs, rhs, output);
+        }
+        Ok(view_from_storage(storage, &a.shape))
+    }
+}
+
+pub(crate) mod unary {
+    use super::*;
     use std::ops::Neg;
 
-    pub fn neg(a: &[f32], b: &mut [f32]) {
-        assert!(a.len() == b.len());
+    pub fn apply(op: UnaryOp, input: &[f32], output: &mut [f32]) {
+        assert_eq!(input.len(), output.len());
+        match op {
+            UnaryOp::Neg => neg(input, output),
+            UnaryOp::Sqrt => sqrt(input, output),
+            UnaryOp::Exp => exp(input, output),
+            UnaryOp::Log => log(input, output),
+            UnaryOp::Relu => relu(input, output),
+        }
+    }
+
+    fn neg(a: &[f32], b: &mut [f32]) {
         let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE);
         let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE);
         a_main
@@ -21,8 +77,7 @@ pub mod unary_ops {
             .for_each(|(a, b)| *b = -a);
     }
 
-    pub fn sqrt(a: &[f32], b: &mut [f32]) {
-        assert!(a.len() == b.len());
+    fn sqrt(a: &[f32], b: &mut [f32]) {
         let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE);
         let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE);
         a_main
@@ -35,8 +90,7 @@ pub mod unary_ops {
             .for_each(|(a, b)| *b = a.sqrt());
     }
 
-    pub fn relu(a: &[f32], b: &mut [f32]) {
-        assert!(a.len() == b.len());
+    fn relu(a: &[f32], b: &mut [f32]) {
         let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE);
         let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE);
         let zero = f32x64::splat(0.0);
@@ -50,8 +104,7 @@ pub mod unary_ops {
             .for_each(|(a, b)| *b = a.max(0.0));
     }
 
-    pub fn exp(a: &[f32], b: &mut [f32]) {
-        assert!(a.len() == b.len());
+    fn exp(a: &[f32], b: &mut [f32]) {
         let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE);
         let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE);
         a_main
@@ -64,8 +117,7 @@ pub mod unary_ops {
             .for_each(|(a, b)| *b = a.exp());
     }
 
-    pub fn log(a: &[f32], b: &mut [f32]) {
-        assert!(a.len() == b.len());
+    fn log(a: &[f32], b: &mut [f32]) {
         let (a_main, a_rem) = a.split_at(a.len() - a.len() % CHUNK_SIZE);
         let (b_main, b_rem) = b.split_at_mut(b.len() - b.len() % CHUNK_SIZE);
         a_main
@@ -79,12 +131,22 @@ pub mod unary_ops {
     }
 }
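The `unary::apply` entry point in this hunk is a length-checked match that forwards each `UnaryOp` variant to its kernel. A rough Python analogue of that dispatch-table shape (the string op names and the table itself are invented for illustration, not part of the crate):

```python
import math

# Stand-ins for the UnaryOp variants and their scalar kernels.
UNARY_KERNELS = {
    "neg": lambda x: -x,
    "sqrt": math.sqrt,
    "exp": math.exp,
    "log": math.log,
    "relu": lambda x: max(x, 0.0),
}

def apply_unary(op: str, values: list) -> list:
    """Dispatch once on the op, then apply the chosen kernel elementwise."""
    kernel = UNARY_KERNELS[op]
    return [kernel(v) for v in values]
```

The Rust version additionally writes into a caller-provided output slice and asserts equal input/output lengths before dispatching.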

-pub mod binary_ops {
+pub(crate) mod binary {
+    use super::*;
     use std::ops::{Add, Div, Mul, Sub};
 
-    pub fn add(a: &[f32], b: &[f32], c: &mut [f32]) {
-        assert!(a.len() == b.len() && b.len() == c.len());
+    pub fn apply(op: BinaryOp, lhs: &[f32], rhs: &[f32], output: &mut [f32]) {
+        assert_eq!(lhs.len(), rhs.len());
+        assert_eq!(rhs.len(), output.len());
+        match op {
+            BinaryOp::Add => add(lhs, rhs, output),
+            BinaryOp::Sub => sub(lhs, rhs, output),
+            BinaryOp::Mul => mul(lhs, rhs, output),
+            BinaryOp::Div => div(lhs, rhs, output),
+        }
+    }
+
+    fn add(a: &[f32], b: &[f32], c: &mut [f32]) {
         let main = a.len() - a.len() % CHUNK_SIZE;
         let (a_main, a_rem) = a.split_at(main);
         let (b_main, b_rem) = b.split_at(main);
@@ -105,8 +167,7 @@ pub mod binary_ops {
             .for_each(|((a, b), c)| *c = a + b);
     }
 
-    pub fn sub(a: &[f32], b: &[f32], c: &mut [f32]) {
-        assert!(a.len() == b.len() && b.len() == c.len());
+    fn sub(a: &[f32], b: &[f32], c: &mut [f32]) {
         let main = a.len() - a.len() % CHUNK_SIZE;
         let (a_main, a_rem) = a.split_at(main);
         let (b_main, b_rem) = b.split_at(main);
@@ -127,8 +188,7 @@ pub mod binary_ops {
             .for_each(|((a, b), c)| *c = a - b);
     }
 
-    pub fn mul(a: &[f32], b: &[f32], c: &mut [f32]) {
-        assert!(a.len() == b.len() && b.len() == c.len());
+    fn mul(a: &[f32], b: &[f32], c: &mut [f32]) {
         let main = a.len() - a.len() % CHUNK_SIZE;
         let (a_main, a_rem) = a.split_at(main);
         let (b_main, b_rem) = b.split_at(main);
@@ -149,8 +209,7 @@ pub mod binary_ops {
             .for_each(|((a, b), c)| *c = a * b);
     }
 
-    pub fn div(a: &[f32], b: &[f32], c: &mut [f32]) {
-        assert!(a.len() == b.len() && b.len() == c.len());
+    fn div(a: &[f32], b: &[f32], c: &mut [f32]) {
         let main = a.len() - a.len() % CHUNK_SIZE;
         let (a_main, a_rem) = a.split_at(main);
         let (b_main, b_rem) = b.split_at(main);
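Every kernel in this file follows the same layout: split each slice at the largest multiple of `CHUNK_SIZE`, process the aligned prefix in SIMD-width chunks, then finish the remainder with a scalar loop. A plain-Python sketch of that split, using `neg` as the example (no SIMD here, and `CHUNK_SIZE` is shrunk for illustration):

```python
CHUNK_SIZE = 4  # the Rust kernels use 64 f32 lanes

def neg(a: list) -> list:
    out = [0.0] * len(a)
    main = len(a) - len(a) % CHUNK_SIZE  # largest chunk-aligned prefix
    # "SIMD" part: walk the prefix one full chunk at a time.
    for i in range(0, main, CHUNK_SIZE):
        for j in range(i, i + CHUNK_SIZE):
            out[j] = -a[j]
    # Scalar remainder, like the .zip(a_rem)/.zip(b_rem) tail loops above.
    for j in range(main, len(a)):
        out[j] = -a[j]
    return out
```

The split means the hot loop never needs a bounds check for a partial chunk; only the (at most `CHUNK_SIZE - 1` element) tail runs scalar code.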