diff --git a/catgrad/src/interpreter/backend/candle.rs b/catgrad/src/interpreter/backend/candle.rs index ed68f37f..63dc77f1 100644 --- a/catgrad/src/interpreter/backend/candle.rs +++ b/catgrad/src/interpreter/backend/candle.rs @@ -32,7 +32,113 @@ use candle_core::{ // ============================================================================ #[derive(Clone, Debug)] -pub struct CandleTensor(pub Tensor); +struct DeferredIndex0 { + indices: Tensor, + dim: usize, +} + +#[derive(Clone, Debug)] +pub struct CandleTensor { + tensor: Tensor, + deferred_index0: Option, +} + +impl CandleTensor { + fn from_materialized(tensor: Tensor) -> Self { + Self { + tensor, + deferred_index0: None, + } + } + + fn from_indexed_select(tensor: &Tensor, dim: usize, indices: &Tensor) -> Self { + Self { + tensor: tensor.clone(), + deferred_index0: Some(DeferredIndex0 { + indices: indices.flatten_all().unwrap(), + dim, + }), + } + } + + pub fn materialize(&self) -> Tensor { + match &self.deferred_index0 { + None => self.tensor.clone(), + Some(DeferredIndex0 { indices, dim }) => { + CandleBackend::index_tensor_materialized(&self.tensor, *dim, indices) + } + } + } + + fn transpose(&self, dim0: usize, dim1: usize) -> Self { + let tensor = self.tensor.transpose(dim0, dim1).unwrap(); + let deferred_index0 = + self.deferred_index0 + .as_ref() + .map(|DeferredIndex0 { indices, dim }| DeferredIndex0 { + indices: indices.clone(), + dim: if *dim == dim0 { + dim1 + } else if *dim == dim1 { + dim0 + } else { + *dim + }, + }); + Self { + tensor, + deferred_index0, + } + } + + fn slice(&self, dim: usize, start: usize, len: usize) -> Self { + match &self.deferred_index0 { + None => Self::from_materialized(CandleBackend::slice_tensor_materialized( + &self.tensor, + dim, + start, + len, + )), + Some(DeferredIndex0 { + indices, + dim: indexed_dim, + }) if dim == *indexed_dim => Self { + tensor: self.tensor.clone(), + deferred_index0: Some(DeferredIndex0 { + indices: indices.narrow(0, start, len).unwrap(), + dim: *indexed_dim, + }), + }, + Some(DeferredIndex0 { + indices, + dim: indexed_dim, + }) => Self { + tensor: self.tensor.narrow(dim, start, len).unwrap(), + deferred_index0: Some(DeferredIndex0 { + indices: indices.clone(), + dim: *indexed_dim, + }), + }, + } + } + + fn shape(&self) -> Shape { + match &self.deferred_index0 { + None => Shape(self.tensor.dims().to_vec()), + Some(DeferredIndex0 { indices, dim }) => { + let mut dims = self.tensor.dims().to_vec(); + dims[*dim] = indices.dims1().unwrap(); + Shape(dims) + } + } + } +} + +impl From for CandleTensor { + fn from(value: Tensor) -> Self { + Self::from_materialized(value) + } +} #[derive(Clone, Debug)] pub struct CandleBackend { @@ -81,21 +187,31 @@ impl Backend for CandleBackend { fn to_vec(&self, vec: TaggedTensor) -> TaggedVec { match vec { - TaggedTensor::F32([x]) => TaggedVec::F32(x.0.flatten_all().unwrap().to_vec1().unwrap()), - TaggedTensor::F16([x]) => TaggedVec::F16(x.0.flatten_all().unwrap().to_vec1().unwrap()), + TaggedTensor::F32([x]) => { + let x = x.materialize(); + TaggedVec::F32(x.flatten_all().unwrap().to_vec1().unwrap()) + } + TaggedTensor::F16([x]) => { + let x = x.materialize(); + TaggedVec::F16(x.flatten_all().unwrap().to_vec1().unwrap()) + } TaggedTensor::BF16([x]) => { - TaggedVec::BF16(x.0.flatten_all().unwrap().to_vec1().unwrap()) + let x = x.materialize(); + TaggedVec::BF16(x.flatten_all().unwrap().to_vec1().unwrap()) + } + TaggedTensor::U32([x]) => { + let x = x.materialize(); + TaggedVec::U32(x.flatten_all().unwrap().to_vec1().unwrap()) } - TaggedTensor::U32([x]) => TaggedVec::U32(x.0.flatten_all().unwrap().to_vec1().unwrap()), } } fn format_tensor(&self, tensor: &TaggedTensor) -> String { match tensor { - TaggedTensor::F32([x]) => format!("{}", &x.0), - TaggedTensor::F16([x]) => format!("{}", &x.0), - TaggedTensor::BF16([x]) => format!("{}", &x.0), - TaggedTensor::U32([x]) => format!("{}", &x.0), + TaggedTensor::F32([x]) => format!("{}", x.materialize()), + TaggedTensor::F16([x]) => format!("{}", x.materialize()), + TaggedTensor::BF16([x]) => format!("{}", x.materialize()), + TaggedTensor::U32([x]) => format!("{}", x.materialize()), } } @@ -104,20 +220,20 @@ impl Backend for CandleBackend { match target_dtype { Dtype::F32 => { let tensor = Tensor::zeros(dims, DType::F32, &self.device).unwrap(); - TaggedTensor::F32([CandleTensor(tensor)]) + TaggedTensor::F32([tensor.into()]) } Dtype::F16 => { let tensor = Tensor::zeros(dims, DType::F16, &self.device).unwrap(); - TaggedTensor::F16([CandleTensor(tensor)]) + TaggedTensor::F16([tensor.into()]) } Dtype::BF16 => { self.ensure_dtype_supported(DType::BF16); let tensor = Tensor::zeros(dims, DType::BF16, &self.device).unwrap(); - TaggedTensor::BF16([CandleTensor(tensor)]) + TaggedTensor::BF16([tensor.into()]) } Dtype::U32 => { let tensor = Tensor::zeros(dims, DType::U32, &self.device).unwrap(); - TaggedTensor::U32([CandleTensor(tensor)]) + TaggedTensor::U32([tensor.into()]) } } } @@ -130,7 +246,7 @@ impl Backend for CandleBackend { let dims: &[usize] = &shape.0; let tensor = Tensor::from_vec(data, dims, &self.device).map_err(|_| BackendError::ShapeError)?; - Ok(TaggedTensor::F32([CandleTensor(tensor)])) + Ok(TaggedTensor::F32([tensor.into()])) } fn ndarray_from_vec_f16( @@ -141,7 +257,7 @@ impl Backend for CandleBackend { let dims: &[usize] = &shape.0; let tensor = Tensor::from_vec(data, dims, &self.device).map_err(|_| BackendError::ShapeError)?; - Ok(TaggedTensor::F16([CandleTensor(tensor)])) + Ok(TaggedTensor::F16([tensor.into()])) } fn ndarray_from_vec_bf16( @@ -153,7 +269,7 @@ impl Backend for CandleBackend { let dims: &[usize] = &shape.0; let tensor = Tensor::from_vec(data, dims, &self.device).map_err(|_| BackendError::ShapeError)?; - Ok(TaggedTensor::BF16([CandleTensor(tensor)])) + Ok(TaggedTensor::BF16([tensor.into()])) } fn ndarray_from_vec_u32( @@ -164,7 +280,7 @@ impl Backend for CandleBackend { let dims: &[usize] = &shape.0; let tensor = Tensor::from_vec(data, dims, &self.device).map_err(|_| BackendError::ShapeError)?; - Ok(TaggedTensor::U32([CandleTensor(tensor)])) + Ok(TaggedTensor::U32([tensor.into()])) } fn cast(&self, x: TaggedTensor, target_dtype: Dtype) -> TaggedTensor { @@ -176,7 +292,7 @@ impl Backend for CandleBackend { TaggedTensor::F32([arr]) | TaggedTensor::F16([arr]) | TaggedTensor::BF16([arr]) - | TaggedTensor::U32([arr]) => arr.0, + | TaggedTensor::U32([arr]) => arr.materialize(), }; let target_dtype = match target_dtype { Dtype::F32 => DType::F32, @@ -185,7 +301,7 @@ impl Backend for CandleBackend { Dtype::U32 => DType::U32, }; self.ensure_dtype_supported(target_dtype); - let tensor = CandleTensor(tensor.to_dtype(target_dtype).unwrap()); + let tensor: CandleTensor = tensor.to_dtype(target_dtype).unwrap().into(); match target_dtype { DType::F32 => TaggedTensor::F32([tensor]), DType::F16 => TaggedTensor::F16([tensor]), @@ -198,139 +314,139 @@ impl Backend for CandleBackend { fn matmul(&self, lhs: TaggedTensorTuple) -> TaggedTensor { use TaggedTensorTuple::*; match lhs { - F32([x, y]) => F32([CandleTensor(Self::batched_matmul(x.0, y.0))]), - F16([x, y]) => F16([CandleTensor(Self::batched_matmul(x.0, y.0))]), - BF16([x, y]) => BF16([CandleTensor(Self::batched_matmul(x.0, y.0))]), - U32([x, y]) => U32([CandleTensor(Self::batched_matmul(x.0, y.0))]), + F32([x, y]) => F32([Self::matmul_tensors(x, y)]), + F16([x, y]) => F16([Self::matmul_tensors(x, y)]), + BF16([x, y]) => BF16([Self::matmul_tensors(x, y)]), + U32([x, y]) => U32([Self::matmul_tensors(x, y)]), } } fn add(&self, lhs: TaggedTensorTuple) -> TaggedTensor { use TaggedTensorTuple::*; match lhs { - F32([x, y]) => F32([Self::add(x, y)]), - F16([x, y]) => F16([Self::add(x, y)]), - BF16([x, y]) => BF16([Self::add(x, y)]), - U32([x, y]) => U32([Self::add(x, y)]), + F32([x, y]) => F32([Self::binary_eager(x, y, Self::add)]), + F16([x, y]) => F16([Self::binary_eager(x, y, Self::add)]), + BF16([x, y]) => BF16([Self::binary_eager(x, y, Self::add)]), + U32([x, y]) => U32([Self::binary_eager(x, y, Self::add)]), } } fn sub(&self, lhs: TaggedTensorTuple) -> TaggedTensor { use TaggedTensorTuple::*; match lhs { - F32([x, y]) => F32([Self::sub(x, y)]), - F16([x, y]) => F16([Self::sub(x, y)]), - BF16([x, y]) => BF16([Self::sub(x, y)]), - U32([x, y]) => U32([Self::sub(x, y)]), + F32([x, y]) => F32([Self::binary_eager(x, y, Self::sub)]), + F16([x, y]) => F16([Self::binary_eager(x, y, Self::sub)]), + BF16([x, y]) => BF16([Self::binary_eager(x, y, Self::sub)]), + U32([x, y]) => U32([Self::binary_eager(x, y, Self::sub)]), } } fn mul(&self, lhs: TaggedTensorTuple) -> TaggedTensor { use TaggedTensorTuple::*; match lhs { - F32([x, y]) => F32([Self::mul(x, y)]), - F16([x, y]) => F16([Self::mul(x, y)]), - BF16([x, y]) => BF16([Self::mul(x, y)]), - U32([x, y]) => U32([Self::mul(x, y)]), + F32([x, y]) => F32([Self::binary_eager(x, y, Self::mul)]), + F16([x, y]) => F16([Self::binary_eager(x, y, Self::mul)]), + BF16([x, y]) => BF16([Self::binary_eager(x, y, Self::mul)]), + U32([x, y]) => U32([Self::binary_eager(x, y, Self::mul)]), } } fn div(&self, lhs: TaggedTensorTuple) -> TaggedTensor { use TaggedTensorTuple::*; match lhs { - F32([x, y]) => F32([Self::div(x, y)]), - F16([x, y]) => F16([Self::div(x, y)]), - BF16([x, y]) => BF16([Self::div(x, y)]), - U32([x, y]) => U32([Self::div(x, y)]), + F32([x, y]) => F32([Self::binary_eager(x, y, Self::div)]), + F16([x, y]) => F16([Self::binary_eager(x, y, Self::div)]), + BF16([x, y]) => BF16([Self::binary_eager(x, y, Self::div)]), + U32([x, y]) => U32([Self::binary_eager(x, y, Self::div)]), } } fn lt(&self, lhs: TaggedTensorTuple) -> TaggedTensor { use TaggedTensorTuple::*; match lhs { - F32([x, y]) => F32([Self::lt(x, y)]), - F16([x, y]) => F16([Self::lt(x, y)]), - BF16([x, y]) => BF16([Self::lt(x, y)]), - U32([x, y]) => U32([Self::lt(x, y)]), + F32([x, y]) => F32([Self::binary_eager(x, y, Self::lt)]), + F16([x, y]) => F16([Self::binary_eager(x, y, Self::lt)]), + BF16([x, y]) => BF16([Self::binary_eager(x, y, Self::lt)]), + U32([x, y]) => U32([Self::binary_eager(x, y, Self::lt)]), } } fn gt(&self, lhs: TaggedTensorTuple) -> TaggedTensor { use TaggedTensorTuple::*; match lhs { - F32([x, y]) => F32([Self::gt(x, y)]), - F16([x, y]) => F16([Self::gt(x, y)]), - BF16([x, y]) => BF16([Self::gt(x, y)]), - U32([x, y]) => U32([Self::gt(x, y)]), + F32([x, y]) => F32([Self::binary_eager(x, y, Self::gt)]), + F16([x, y]) => F16([Self::binary_eager(x, y, Self::gt)]), + BF16([x, y]) => BF16([Self::binary_eager(x, y, Self::gt)]), + U32([x, y]) => U32([Self::binary_eager(x, y, Self::gt)]), } } fn gte(&self, lhs: TaggedTensorTuple) -> TaggedTensor { use TaggedTensorTuple::*; match lhs { - F32([x, y]) => F32([Self::gte(x, y)]), - F16([x, y]) => F16([Self::gte(x, y)]), - BF16([x, y]) => BF16([Self::gte(x, y)]), - U32([x, y]) => U32([Self::gte(x, y)]), + F32([x, y]) => F32([Self::binary_eager(x, y, Self::gte)]), + F16([x, y]) => F16([Self::binary_eager(x, y, Self::gte)]), + BF16([x, y]) => BF16([Self::binary_eager(x, y, Self::gte)]), + U32([x, y]) => U32([Self::binary_eager(x, y, Self::gte)]), } } fn lte(&self, lhs: TaggedTensorTuple) -> TaggedTensor { use TaggedTensorTuple::*; match lhs { - F32([x, y]) => F32([Self::lte(x, y)]), - F16([x, y]) => F16([Self::lte(x, y)]), - BF16([x, y]) => BF16([Self::lte(x, y)]), - U32([x, y]) => U32([Self::lte(x, y)]), + F32([x, y]) => F32([Self::binary_eager(x, y, Self::lte)]), + F16([x, y]) => F16([Self::binary_eager(x, y, Self::lte)]), + BF16([x, y]) => BF16([Self::binary_eager(x, y, Self::lte)]), + U32([x, y]) => U32([Self::binary_eager(x, y, Self::lte)]), } } fn eq(&self, lhs: TaggedTensorTuple) -> TaggedTensor { use TaggedTensorTuple::*; match lhs { - F32([x, y]) => F32([Self::eq(x, y)]), - F16([x, y]) => F16([Self::eq(x, y)]), - BF16([x, y]) => BF16([Self::eq(x, y)]), - U32([x, y]) => U32([Self::eq(x, y)]), + F32([x, y]) => F32([Self::binary_eager(x, y, Self::eq)]), + F16([x, y]) => F16([Self::binary_eager(x, y, Self::eq)]), + BF16([x, y]) => BF16([Self::binary_eager(x, y, Self::eq)]), + U32([x, y]) => U32([Self::binary_eager(x, y, Self::eq)]), } } fn where_cond(&self, args: TaggedTensorTuple) -> TaggedTensor { use TaggedTensorTuple::*; match args { - F32([mask, x, y]) => F32([Self::where_cond(mask, x, y)]), - F16([mask, x, y]) => F16([Self::where_cond(mask, x, y)]), - BF16([mask, x, y]) => BF16([Self::where_cond(mask, x, y)]), - U32([mask, x, y]) => U32([Self::where_cond(mask, x, y)]), + F32([mask, x, y]) => F32([Self::ternary_eager(mask, x, y, Self::where_cond)]), + F16([mask, x, y]) => F16([Self::ternary_eager(mask, x, y, Self::where_cond)]), + BF16([mask, x, y]) => BF16([Self::ternary_eager(mask, x, y, Self::where_cond)]), + U32([mask, x, y]) => U32([Self::ternary_eager(mask, x, y, Self::where_cond)]), } } fn pow(&self, lhs: TaggedTensorTuple) -> TaggedTensor { use TaggedTensorTuple::*; match lhs { - F32([x, y]) => F32([Self::pow(x, y)]), - F16([x, y]) => F16([Self::pow(x, y)]), - BF16([x, y]) => BF16([Self::pow(x, y)]), - U32([x, y]) => U32([Self::pow(x, y)]), + F32([x, y]) => F32([Self::binary_eager(x, y, Self::pow)]), + F16([x, y]) => F16([Self::binary_eager(x, y, Self::pow)]), + BF16([x, y]) => BF16([Self::binary_eager(x, y, Self::pow)]), + U32([x, y]) => U32([Self::binary_eager(x, y, Self::pow)]), } } fn neg(&self, x: TaggedTensor) -> TaggedTensor { use TaggedTensorTuple::*; match x { - F32([arr]) => F32([Self::neg(arr)]), - F16([arr]) => F16([Self::neg(arr)]), - BF16([arr]) => BF16([Self::neg(arr)]), - U32([arr]) => U32([Self::neg(arr)]), + F32([arr]) => F32([Self::unary_eager(arr, Self::neg)]), + F16([arr]) => F16([Self::unary_eager(arr, Self::neg)]), + BF16([arr]) => BF16([Self::unary_eager(arr, Self::neg)]), + U32([arr]) => U32([Self::unary_eager(arr, Self::neg)]), } } fn sin(&self, x: TaggedTensor) -> TaggedTensor { use TaggedTensorTuple::*; match x { - F32([arr]) => F32([Self::sin(arr)]), - F16([arr]) => F16([Self::sin(arr)]), - BF16([arr]) => BF16([Self::sin(arr)]), + F32([arr]) => F32([Self::unary_eager(arr, Self::sin)]), + F16([arr]) => F16([Self::unary_eager(arr, Self::sin)]), + BF16([arr]) => BF16([Self::unary_eager(arr, Self::sin)]), _ => panic!("Invalid type for sin"), } } @@ -338,9 +454,9 @@ impl Backend for CandleBackend { fn cos(&self, x: TaggedTensor) -> TaggedTensor { use TaggedTensorTuple::*; match x { - F32([arr]) => F32([Self::cos(arr)]), - F16([arr]) => F16([Self::cos(arr)]), - BF16([arr]) => BF16([Self::cos(arr)]), + F32([arr]) => F32([Self::unary_eager(arr, Self::cos)]), + F16([arr]) => F16([Self::unary_eager(arr, Self::cos)]), + BF16([arr]) => BF16([Self::unary_eager(arr, Self::cos)]), _ => panic!("Invalid type for cos"), } } @@ -348,9 +464,9 @@ impl Backend for CandleBackend { fn log(&self, x: TaggedTensor) -> TaggedTensor { use TaggedTensorTuple::*; match x { - F32([arr]) => F32([Self::log(arr)]), - F16([arr]) => F16([Self::log(arr)]), - BF16([arr]) => BF16([Self::log(arr)]), + F32([arr]) => F32([Self::unary_eager(arr, Self::log)]), + F16([arr]) => F16([Self::unary_eager(arr, Self::log)]), + BF16([arr]) => BF16([Self::unary_eager(arr, Self::log)]), _ => panic!("Invalid type for log"), } } @@ -358,9 +474,9 @@ impl Backend for CandleBackend { fn floor(&self, x: TaggedTensor) -> TaggedTensor { use TaggedTensorTuple::*; match x { - F32([arr]) => F32([Self::floor(arr)]), - F16([arr]) => F16([Self::floor(arr)]), - BF16([arr]) => BF16([Self::floor(arr)]), + F32([arr]) => F32([Self::unary_eager(arr, Self::floor)]), + F16([arr]) => F16([Self::unary_eager(arr, Self::floor)]), + BF16([arr]) => BF16([Self::unary_eager(arr, Self::floor)]), _ => panic!("Invalid type for floor"), } } @@ -368,30 +484,30 @@ impl Backend for CandleBackend { fn max(&self, x: TaggedTensor) -> TaggedTensor { use TaggedTensorTuple::*; match x { - F32([arr]) => F32([Self::max(arr)]), - F16([arr]) => F16([Self::max(arr)]), - BF16([arr]) => BF16([Self::max(arr)]), - U32([arr]) => U32([Self::max(arr)]), + F32([arr]) => F32([Self::unary_eager(arr, Self::max)]), + F16([arr]) => F16([Self::unary_eager(arr, Self::max)]), + BF16([arr]) => BF16([Self::unary_eager(arr, Self::max)]), + U32([arr]) => U32([Self::unary_eager(arr, Self::max)]), } } fn sum(&self, x: TaggedTensor) -> TaggedTensor { use TaggedTensorTuple::*; match x { - F32([arr]) => F32([Self::sum(arr)]), - F16([arr]) => F16([Self::sum(arr)]), - BF16([arr]) => BF16([Self::sum(arr)]), - U32([arr]) => U32([Self::sum(arr)]), + F32([arr]) => F32([Self::unary_eager(arr, Self::sum)]), + F16([arr]) => F16([Self::unary_eager(arr, Self::sum)]), + BF16([arr]) => BF16([Self::unary_eager(arr, Self::sum)]), + U32([arr]) => U32([Self::unary_eager(arr, Self::sum)]), } } fn argmax(&self, x: TaggedTensor) -> TaggedTensor { use TaggedTensorTuple::*; match x { - F32([arr]) => U32([Self::argmax(arr)]), - F16([arr]) => U32([Self::argmax(arr)]), - BF16([arr]) => U32([Self::argmax(arr)]), - U32([arr]) => U32([Self::argmax(arr)]), + F32([arr]) => U32([Self::unary_eager(arr, Self::argmax)]), + F16([arr]) => U32([Self::unary_eager(arr, Self::argmax)]), + BF16([arr]) => U32([Self::unary_eager(arr, Self::argmax)]), + U32([arr]) => U32([Self::unary_eager(arr, Self::argmax)]), } } @@ -399,16 +515,19 @@ impl Backend for CandleBackend { use TaggedTensorTuple::*; match x { F32([arr]) => { - let (values, indices) = Self::topk_f32(arr.0, k); - (F32([CandleTensor(values)]), U32([CandleTensor(indices)])) + let arr = arr.materialize(); + let (values, indices) = Self::topk_f32(&arr, k); + (F32([values.into()]), U32([indices.into()])) } F16([arr]) => { - let (values, indices) = Self::topk_f32(arr.0, k); - (F16([CandleTensor(values)]), U32([CandleTensor(indices)])) + let arr = arr.materialize(); + let (values, indices) = Self::topk_f32(&arr, k); + (F16([values.into()]), U32([indices.into()])) } BF16([arr]) => { - let (values, indices) = Self::topk_f32(arr.0, k); - (BF16([CandleTensor(values)]), U32([CandleTensor(indices)])) + let arr = arr.materialize(); + let (values, indices) = Self::topk_f32(&arr, k); + (BF16([values.into()]), U32([indices.into()])) } _ => panic!("Unsupported type for topk"), } @@ -417,72 +536,96 @@ impl Backend for CandleBackend { fn broadcast(&self, x: TaggedTensor, shape: Shape) -> TaggedTensor { use TaggedTensorTuple::*; match x { - F32([arr]) => F32([CandleTensor(Self::broadcast_tensor(arr.0, shape))]), - F16([arr]) => F16([CandleTensor(Self::broadcast_tensor(arr.0, shape))]), - BF16([arr]) => BF16([CandleTensor(Self::broadcast_tensor(arr.0, shape))]), - U32([arr]) => U32([CandleTensor(Self::broadcast_tensor(arr.0, shape))]), + F32([arr]) => { + let arr = arr.materialize(); + F32([Self::broadcast_tensor(&arr, shape).into()]) + } + F16([arr]) => { + let arr = arr.materialize(); + F16([Self::broadcast_tensor(&arr, shape).into()]) + } + BF16([arr]) => { + let arr = arr.materialize(); + BF16([Self::broadcast_tensor(&arr, shape).into()]) + } + U32([arr]) => { + let arr = arr.materialize(); + U32([Self::broadcast_tensor(&arr, shape).into()]) + } } } fn reshape(&self, x: TaggedTensor, new_shape: Shape) -> TaggedTensor { use TaggedTensorTuple::*; match x { - F32([arr]) => F32([CandleTensor(Self::reshape_tensor(arr.0, new_shape))]), - F16([arr]) => F16([CandleTensor(Self::reshape_tensor(arr.0, new_shape))]), - BF16([arr]) => BF16([CandleTensor(Self::reshape_tensor(arr.0, new_shape))]), - U32([arr]) => U32([CandleTensor(Self::reshape_tensor(arr.0, new_shape))]), + F32([arr]) => { + let arr = arr.materialize(); + F32([Self::reshape_tensor(&arr, new_shape).into()]) + } + F16([arr]) => { + let arr = arr.materialize(); + F16([Self::reshape_tensor(&arr, new_shape).into()]) + } + BF16([arr]) => { + let arr = arr.materialize(); + BF16([Self::reshape_tensor(&arr, new_shape).into()]) + } + U32([arr]) => { + let arr = arr.materialize(); + U32([Self::reshape_tensor(&arr, new_shape).into()]) + } } } fn transpose(&self, x: TaggedTensor, dim0: usize, dim1: usize) -> TaggedTensor { use TaggedTensorTuple::*; match x { - F32([arr]) => F32([CandleTensor(Self::transpose_tensor(arr.0, dim0, dim1))]), - F16([arr]) => F16([CandleTensor(Self::transpose_tensor(arr.0, dim0, dim1))]), - BF16([arr]) => BF16([CandleTensor(Self::transpose_tensor(arr.0, dim0, dim1))]), - U32([arr]) => U32([CandleTensor(Self::transpose_tensor(arr.0, dim0, dim1))]), + F32([arr]) => F32([arr.transpose(dim0, dim1)]), + F16([arr]) => F16([arr.transpose(dim0, dim1)]), + BF16([arr]) => BF16([arr.transpose(dim0, dim1)]), + U32([arr]) => U32([arr.transpose(dim0, dim1)]), } } fn arange(&self, end: usize) -> TaggedTensor { use TaggedTensorTuple::*; let r = Tensor::arange(0, end as u32, &self.device).unwrap(); - U32([CandleTensor(r)]) + U32([r.into()]) } fn to_bool(&self, x: TaggedTensor) -> bool { match x { - TaggedTensor::F32([x]) => { - x.0.gt(0.0) - .ok() - .and_then(|t| t.max_all().ok()) - .and_then(|m| m.to_scalar::().ok()) - .map(|s| s == 1) - .unwrap_or(false) - } - TaggedTensor::F16([x]) => { - x.0.gt(0.0) - .ok() - .and_then(|t| t.max_all().ok()) - .and_then(|m| m.to_scalar::().ok()) - .map(|s| s == 1) - .unwrap_or(false) - } - TaggedTensor::BF16([x]) => { - x.0.gt(0.0) - .ok() - .and_then(|t| t.max_all().ok()) - .and_then(|m| m.to_scalar::().ok()) - .map(|s| s == 1) - .unwrap_or(false) - } - TaggedTensor::U32([x]) => { - x.0.ne(0u32) - .ok() - .and_then(|t| t.max_all().ok()) - .and_then(|m| m.to_scalar::().ok()) - .map(|s| s == 1) - .unwrap_or(false) - } + TaggedTensor::F32([x]) => x + .materialize() + .gt(0.0) + .ok() + .and_then(|t| t.max_all().ok()) + .and_then(|m| m.to_scalar::().ok()) + .map(|s| s == 1) + .unwrap_or(false), + TaggedTensor::F16([x]) => x + .materialize() + .gt(0.0) + .ok() + .and_then(|t| t.max_all().ok()) + .and_then(|m| m.to_scalar::().ok()) + .map(|s| s == 1) + .unwrap_or(false), + TaggedTensor::BF16([x]) => x + .materialize() + .gt(0.0) + .ok() + .and_then(|t| t.max_all().ok()) + .and_then(|m| m.to_scalar::().ok()) + .map(|s| s == 1) + .unwrap_or(false), + TaggedTensor::U32([x]) => x + .materialize() + .ne(0u32) + .ok() + .and_then(|t| t.max_all().ok()) + .and_then(|m| m.to_scalar::().ok()) + .map(|s| s == 1) + .unwrap_or(false), } } @@ -494,18 +637,10 @@ impl Backend for CandleBackend { ) -> TaggedTensor { use TaggedTensorTuple::*; match (x, indices) { - (F32([arr]), U32([indices])) => { - F32([CandleTensor(Self::index_tensor(arr.0, dim, indices.0))]) - } - (F16([arr]), U32([indices])) => { - F16([CandleTensor(Self::index_tensor(arr.0, dim, indices.0))]) - } - (BF16([arr]), U32([indices])) => { - BF16([CandleTensor(Self::index_tensor(arr.0, dim, indices.0))]) - } - (U32([arr]), U32([indices])) => { - U32([CandleTensor(Self::index_tensor(arr.0, dim, indices.0))]) - } + (F32([arr]), U32([indices])) => F32([Self::index_tensor(arr, dim, indices)]), + (F16([arr]), U32([indices])) => F16([Self::index_tensor(arr, dim, indices)]), + (BF16([arr]), U32([indices])) => BF16([Self::index_tensor(arr, dim, indices)]), + (U32([arr]), U32([indices])) => U32([Self::index_tensor(arr, dim, indices)]), _ => panic!("Invalid index type"), } } @@ -519,20 +654,20 @@ impl Backend for CandleBackend { ) -> TaggedTensor { use TaggedTensorTuple::*; match x { - F32([arr]) => F32([CandleTensor(Self::slice_tensor(arr.0, dim, start, len))]), - F16([arr]) => F16([CandleTensor(Self::slice_tensor(arr.0, dim, start, len))]), - BF16([arr]) => BF16([CandleTensor(Self::slice_tensor(arr.0, dim, start, len))]), - U32([arr]) => U32([CandleTensor(Self::slice_tensor(arr.0, dim, start, len))]), + F32([arr]) => F32([arr.slice(dim, start, len)]), + F16([arr]) => F16([arr.slice(dim, start, len)]), + BF16([arr]) => BF16([arr.slice(dim, start, len)]), + U32([arr]) => U32([arr.slice(dim, start, len)]), } } fn compare(&self, x: TaggedTensorTuple) -> bool { use TaggedTensorTuple::*; match x { - F32([a, b]) => Self::compare_tensors(&a.0, &b.0), - F16([a, b]) => Self::compare_tensors(&a.0, &b.0), - BF16([a, b]) => Self::compare_tensors(&a.0, &b.0), - U32([a, b]) => Self::compare_tensors(&a.0, &b.0), + F32([a, b]) => Self::compare_tensors(&a.materialize(), &b.materialize()), + F16([a, b]) => Self::compare_tensors(&a.materialize(), &b.materialize()), + BF16([a, b]) => Self::compare_tensors(&a.materialize(), &b.materialize()), + U32([a, b]) => Self::compare_tensors(&a.materialize(), &b.materialize()), } } @@ -544,19 +679,27 @@ impl Backend for CandleBackend { ) -> TaggedTensor { use TaggedTensorTuple::*; match (x, y) { - (F32([a]), F32([b])) => F32([CandleTensor(Self::concat_tensors(&a.0, &b.0, dim))]), - (F16([a]), F16([b])) => F16([CandleTensor(Self::concat_tensors(&a.0, &b.0, dim))]), - (BF16([a]), BF16([b])) => BF16([CandleTensor(Self::concat_tensors(&a.0, &b.0, dim))]), - (U32([a]), U32([b])) => U32([CandleTensor(Self::concat_tensors(&a.0, &b.0, dim))]), + (F32([a]), F32([b])) => { + F32([Self::concat_tensors(&a.materialize(), &b.materialize(), dim).into()]) + } + (F16([a]), F16([b])) => { + F16([Self::concat_tensors(&a.materialize(), &b.materialize(), dim).into()]) + } + (BF16([a]), BF16([b])) => { + BF16([Self::concat_tensors(&a.materialize(), &b.materialize(), dim).into()]) + } + (U32([a]), U32([b])) => { + U32([Self::concat_tensors(&a.materialize(), &b.materialize(), dim).into()]) + } _ => panic!("Incompatible types for concatenation"), } } } impl CandleBackend { - fn float_tensor_to_f32_vec(tensor: Tensor) -> Vec { + fn float_tensor_to_f32_vec(tensor: &Tensor) -> Vec { let tensor = match tensor.dtype() { - DType::F32 => tensor, + DType::F32 => tensor.clone(), DType::F16 | DType::BF16 => tensor.to_dtype(DType::F32).unwrap(), dtype => panic!("Unsupported float tensor dtype {dtype:?}"), }; @@ -603,7 +746,7 @@ impl CandleBackend { Tensor::cat(&[a, b], dim).unwrap() } - fn reshape_tensor(tensor: Tensor, new_shape: Shape) -> Tensor { + fn reshape_tensor(tensor: &Tensor, new_shape: Shape) -> Tensor { let dims_s = tensor.dims(); let dims_t = new_shape.0.clone(); @@ -617,140 +760,230 @@ impl CandleBackend { } } - fn transpose_tensor(tensor: Tensor, dim0: usize, dim1: usize) -> Tensor { - tensor.transpose(dim0, dim1).unwrap() + fn unary_eager(x: CandleTensor, op: fn(&Tensor) -> CandleTensor) -> CandleTensor { + let x = x.materialize(); + op(&x) + } + + fn binary_eager( + x: CandleTensor, + y: CandleTensor, + op: fn(&Tensor, &Tensor) -> CandleTensor, + ) -> CandleTensor { + let x = x.materialize(); + let y = y.materialize(); + op(&x, &y) + } + + fn ternary_eager( + x: CandleTensor, + y: CandleTensor, + z: CandleTensor, + op: fn(&Tensor, &Tensor, &Tensor) -> CandleTensor, + ) -> CandleTensor { + let x = x.materialize(); + let y = y.materialize(); + let z = z.materialize(); + op(&x, &y, &z) + } + + fn matmul_tensors(lhs: CandleTensor, rhs: CandleTensor) -> CandleTensor { + match (&lhs.deferred_index0, &rhs.deferred_index0) { + (None, Some(DeferredIndex0 { indices, dim })) if *dim == 0 => { + Self::indexed_batched_matmul_rhs(&lhs.tensor, &rhs.tensor, indices).into() + } + _ => { + let lhs = lhs.materialize(); + let rhs = rhs.materialize(); + Self::batched_matmul(&lhs, &rhs).into() + } + } + } + + fn indexed_batched_matmul_rhs(lhs: &Tensor, rhs: &Tensor, indices: &Tensor) -> Tensor { + let num_experts = rhs.dim(0).unwrap(); + let mut positions_by_expert = vec![Vec::::new(); num_experts]; + + for (position, expert_id) in indices + .flatten_all() + .unwrap() + .to_vec1::() + .unwrap() + .into_iter() + .enumerate() + { + positions_by_expert[expert_id as usize].push(position as u32); + } + + let mut out_dims = lhs.dims().to_vec(); + let rhs_out_dim = *rhs.dims().last().unwrap(); + *out_dims.last_mut().unwrap() = rhs_out_dim; + let mut out = Tensor::zeros(&*out_dims, lhs.dtype(), lhs.device()).unwrap(); + + for (expert_id, positions) in positions_by_expert.into_iter().enumerate() { + if positions.is_empty() { + continue; + } + + let positions_len = positions.len(); + let positions = Tensor::from_vec(positions, positions_len, lhs.device()).unwrap(); + let lhs_chunk = lhs.index_select(&positions, 0).unwrap(); + let rhs_chunk = rhs + .narrow(0, expert_id, 1) + .unwrap() + .squeeze(0) + .unwrap() + .unsqueeze(0) + .unwrap() + .broadcast_as(vec![positions_len, rhs.dims()[1], rhs_out_dim]) + .unwrap(); + let result_chunk = Self::batched_matmul(&lhs_chunk, &rhs_chunk); + out = out.index_add(&positions, &result_chunk, 0).unwrap(); + } + + out + } + + fn index_tensor(input: CandleTensor, dim: usize, indices: CandleTensor) -> CandleTensor { + let indices = indices.materialize(); + if dim == 0 && input.deferred_index0.is_none() { + CandleTensor::from_indexed_select(&input.tensor, dim, &indices) + } else { + let tensor = input.materialize(); + Self::index_tensor_materialized(&tensor, dim, &indices).into() + } } - fn index_tensor(tensor: Tensor, dim: usize, indices: Tensor) -> Tensor { + fn index_tensor_materialized(tensor: &Tensor, dim: usize, indices: &Tensor) -> Tensor { let idx = indices.flatten_all().unwrap(); tensor.index_select(&idx, dim).unwrap() } - fn slice_tensor(tensor: Tensor, dim: usize, start: usize, len: usize) -> Tensor { + fn slice_tensor_materialized(tensor: &Tensor, dim: usize, start: usize, len: usize) -> Tensor { tensor.narrow(dim, start, len).unwrap() } - fn broadcast_tensor(tensor: Tensor, shape: Shape) -> Tensor { + fn broadcast_tensor(tensor: &Tensor, shape: Shape) -> Tensor { tensor.broadcast_as(shape.0).unwrap() } - fn add(x: CandleTensor, y: CandleTensor) -> CandleTensor { - if x.0.dims() != y.0.dims() { + fn add(x: &Tensor, y: &Tensor) -> CandleTensor { + if x.dims() != y.dims() { panic!("Shape mismatch in operation"); } - CandleTensor((&x.0 + &y.0).unwrap()) + ((x + y).unwrap()).into() } - fn sub(x: CandleTensor, y: CandleTensor) -> CandleTensor { - if x.0.dims() != y.0.dims() { + fn sub(x: &Tensor, y: &Tensor) -> CandleTensor { + if x.dims() != y.dims() { panic!("Shape mismatch in operation"); } - CandleTensor((&x.0 - &y.0).unwrap()) + ((x - y).unwrap()).into() } - fn mul(x: CandleTensor, y: CandleTensor) -> CandleTensor { - if x.0.dims() != y.0.dims() { + fn mul(x: &Tensor, y: &Tensor) -> CandleTensor { + if x.dims() != y.dims() { panic!("Shape mismatch in operation"); } - CandleTensor((&x.0 * &y.0).unwrap()) + ((x * y).unwrap()).into() } - fn div(x: CandleTensor, y: CandleTensor) -> CandleTensor { - if x.0.dims() != y.0.dims() { + fn div(x: &Tensor, y: &Tensor) -> CandleTensor { + if x.dims() != y.dims() { panic!("Shape mismatch in operation"); } - CandleTensor((&x.0 / &y.0).unwrap()) + ((x / y).unwrap()).into() } - fn lt(x: CandleTensor, y: CandleTensor) -> CandleTensor { - let dtype = x.0.dtype(); - if x.0.dims() != y.0.dims() { + fn lt(x: &Tensor, y: &Tensor) -> CandleTensor { + let dtype = x.dtype(); + if x.dims() != y.dims() { panic!("Shape mismatch in operation"); } - CandleTensor(x.0.lt(&y.0).unwrap().to_dtype(dtype).unwrap()) + x.lt(y).unwrap().to_dtype(dtype).unwrap().into() } - fn gt(x: CandleTensor, y: CandleTensor) -> CandleTensor { - let dtype = x.0.dtype(); - if x.0.dims() != y.0.dims() { + fn gt(x: &Tensor, y: &Tensor) -> CandleTensor { + let dtype = x.dtype(); + if x.dims() != y.dims() { panic!("Shape mismatch in operation"); } - CandleTensor(x.0.gt(&y.0).unwrap().to_dtype(dtype).unwrap()) + x.gt(y).unwrap().to_dtype(dtype).unwrap().into() } - fn gte(x: CandleTensor, y: CandleTensor) -> CandleTensor { - let dtype = x.0.dtype(); - if x.0.dims() != y.0.dims() { + fn gte(x: &Tensor, y: &Tensor) -> CandleTensor { + let dtype = x.dtype(); + if x.dims() != y.dims() { panic!("Shape mismatch in operation"); } - CandleTensor(x.0.ge(&y.0).unwrap().to_dtype(dtype).unwrap()) + x.ge(y).unwrap().to_dtype(dtype).unwrap().into() } - fn lte(x: CandleTensor, y: CandleTensor) -> CandleTensor { - let dtype = x.0.dtype(); - if x.0.dims() != y.0.dims() { + fn lte(x: &Tensor, y: &Tensor) -> CandleTensor { + let dtype = x.dtype(); + if x.dims() != y.dims() { panic!("Shape mismatch in operation"); } - CandleTensor(x.0.le(&y.0).unwrap().to_dtype(dtype).unwrap()) + x.le(y).unwrap().to_dtype(dtype).unwrap().into() } - fn eq(x: CandleTensor, y: CandleTensor) -> CandleTensor { - let dtype = x.0.dtype(); - if x.0.dims() != y.0.dims() { + fn eq(x: &Tensor, y: &Tensor) -> CandleTensor { + let dtype = x.dtype(); + if x.dims() != y.dims() { panic!("Shape mismatch in operation"); } - CandleTensor(x.0.eq(&y.0).unwrap().to_dtype(dtype).unwrap()) + x.eq(y).unwrap().to_dtype(dtype).unwrap().into() } - fn where_cond(mask: CandleTensor, x: CandleTensor, y: CandleTensor) -> CandleTensor { - let mask = match mask.0.dtype() { - DType::F32 | DType::F16 | DType::BF16 => mask.0.gt(0.).unwrap(), - DType::U32 => mask.0.ne(0u32).unwrap(), - _ => mask.0, // already U8 (boolean) or other type + fn where_cond(mask: &Tensor, x: &Tensor, y: &Tensor) -> CandleTensor { + let mask = match mask.dtype() { + DType::F32 | DType::F16 | DType::BF16 => mask.gt(0.).unwrap(), + DType::U32 => mask.ne(0u32).unwrap(), + _ => mask.clone(), // already U8 (boolean) or other type }; - CandleTensor(mask.where_cond(&x.0, &y.0).unwrap()) + mask.where_cond(x, y).unwrap().into() } - fn neg(x: CandleTensor) -> CandleTensor { - CandleTensor(x.0.neg().unwrap()) + fn neg(x: &Tensor) -> CandleTensor { + x.neg().unwrap().into() } - fn sin(x: CandleTensor) -> CandleTensor { - CandleTensor(x.0.sin().unwrap()) + fn sin(x: &Tensor) -> CandleTensor { + x.sin().unwrap().into() } - fn cos(x: CandleTensor) -> CandleTensor { - CandleTensor(x.0.cos().unwrap()) + fn cos(x: &Tensor) -> CandleTensor { + x.cos().unwrap().into() } - fn log(x: CandleTensor) -> CandleTensor { - CandleTensor(x.0.log().unwrap()) + fn log(x: &Tensor) -> CandleTensor { + x.log().unwrap().into() } - fn floor(x: CandleTensor) -> CandleTensor { - CandleTensor(x.0.floor().unwrap()) + fn floor(x: &Tensor) -> CandleTensor { + x.floor().unwrap().into() } // Candle's pow function does not support negative base and silently generates NaNs // so we do element-wise powf https://github.com/huggingface/candle/issues/1640 - fn pow(x: CandleTensor, y: CandleTensor) -> CandleTensor { - if x.0.dims() != y.0.dims() { + fn pow(x: &Tensor, y: &Tensor) -> CandleTensor { + if x.dims() != y.dims() { panic!("Shape mismatch in operation"); } - let dtype = x.0.dtype(); - let shape = x.0.dims().to_vec(); - let device = x.0.device().clone(); + let dtype = x.dtype(); + let shape = x.dims().to_vec(); + let device = x.device().clone(); // Convert tensors to vectors for element-wise powf operation - let x_vec = Self::float_tensor_to_f32_vec(x.0); - let y_vec = Self::float_tensor_to_f32_vec(y.0); + let x_vec = Self::float_tensor_to_f32_vec(x); + let y_vec = Self::float_tensor_to_f32_vec(y); // Perform element-wise powf let result_vec: Vec = x_vec @@ -766,35 +999,43 @@ impl CandleBackend { } else { result_tensor.to_dtype(dtype).unwrap() }; - CandleTensor(result_tensor) + result_tensor.into() } - fn sum(x: CandleTensor) -> CandleTensor { - CandleTensor(x.0.sum_keepdim(D::Minus1).unwrap()) + fn sum(x: &Tensor) -> CandleTensor { + x.sum_keepdim(D::Minus1).unwrap().into() } - fn max(x: CandleTensor) -> CandleTensor { - CandleTensor(x.0.max_keepdim(D::Minus1).unwrap()) + fn max(x: &Tensor) -> CandleTensor { + x.max_keepdim(D::Minus1).unwrap().into() } - fn argmax(x: CandleTensor) -> CandleTensor { - CandleTensor(x.0.argmax_keepdim(D::Minus1).unwrap()) + fn argmax(x: &Tensor) -> CandleTensor { + x.argmax_keepdim(D::Minus1).unwrap().into() } - fn topk_f32(tensor: Tensor, k: usize) -> (Tensor, Tensor) { + fn topk_f32(tensor: &Tensor, k: usize) -> (Tensor, Tensor) { let (values, indices) = tensor.sort_last_dim(false).unwrap(); let topk_indices = indices.narrow(D::Minus1, 0, k).unwrap(); let topk_values = values.narrow(D::Minus1, 0, k).unwrap(); (topk_values, topk_indices) } - pub fn batched_matmul(lhs: Tensor, rhs: Tensor) -> Tensor { - match lhs.matmul(&rhs) { + fn contiguous_if_needed(tensor: &Tensor) -> Tensor { + if tensor.is_contiguous() { + tensor.clone() + } else { + tensor.contiguous().unwrap() + } + } + + pub fn batched_matmul(lhs: &Tensor, rhs: &Tensor) -> Tensor { + match lhs.matmul(rhs) { Ok(result) => result, - // On error retry with contiguous inputs + // On error retry with contiguous inputs. Err(_) => { - let lhs = lhs.contiguous().unwrap(); - let rhs = rhs.contiguous().unwrap(); + let lhs = Self::contiguous_if_needed(lhs); + let rhs = Self::contiguous_if_needed(rhs); lhs.matmul(&rhs).unwrap() } } @@ -803,7 +1044,7 @@ impl CandleBackend { impl BackendTensorOps for CandleTensor { fn shape(&self) -> Shape { - Shape(self.0.dims().to_vec()) + self.shape() } } @@ -836,7 +1077,7 @@ fn test_batched_matmul() { .reshape(&[2, 3, 2, 1]) .unwrap(); - let result = CandleBackend::batched_matmul(lhs, rhs); + let result = CandleBackend::batched_matmul(&lhs, &rhs); // Expected shape: [2, 3, 2, 1] assert_eq!(result.dims(), &[2, 3, 2, 1]); @@ -859,3 +1100,82 @@ fn test_batched_matmul() { ); } } + +#[test] +fn test_indexed_select_to_vec_matches_materialized_gather() { + let tensor = Tensor::new( + &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], + &candle_core::Device::Cpu, + ) + .unwrap() + .reshape(&[3, 2]) + .unwrap(); + let indices = Tensor::new(&[2u32, 0], &candle_core::Device::Cpu).unwrap(); + + let expected = CandleBackend::index_tensor_materialized(&tensor, 0, &indices); + let actual = CandleTensor::from_indexed_select(&tensor, 0, &indices).materialize(); + + assert_eq!( + actual.flatten_all().unwrap().to_vec1::().unwrap(), + expected.flatten_all().unwrap().to_vec1::().unwrap() + ); +} + +#[test] +fn test_indexed_select_slice_matches_materialized_gather() { + let tensor = Tensor::new( + &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + &candle_core::Device::Cpu, + ) + .unwrap() + .reshape(&[4, 2]) + .unwrap(); + let indices = Tensor::new(&[3u32, 1, 2], &candle_core::Device::Cpu).unwrap(); + + let expected = CandleBackend::index_tensor_materialized(&tensor, 0, &indices) + .narrow(0, 1, 2) + .unwrap(); + let actual = CandleTensor::from_indexed_select(&tensor, 0, &indices) + .slice(0, 1, 2) + .materialize(); + + assert_eq!( + actual.flatten_all().unwrap().to_vec1::().unwrap(), + expected.flatten_all().unwrap().to_vec1::().unwrap() + ); +} + +#[test] +fn test_indexed_select_rhs_matmul_matches_materialized_gather() { + let lhs = Tensor::new( + &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], + &candle_core::Device::Cpu, + ) + .unwrap() + .reshape(&[3, 1, 2]) + .unwrap(); + let rhs = Tensor::new( + &[ + 1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ], + &candle_core::Device::Cpu, + ) + .unwrap() + .reshape(&[3, 2, 2]) + .unwrap(); + let indices = Tensor::new(&[2u32, 0, 2], &candle_core::Device::Cpu).unwrap(); + + let expected_rhs = CandleBackend::index_tensor_materialized(&rhs, 0, &indices); + let expected = CandleBackend::batched_matmul(&lhs, &expected_rhs); + let actual = CandleBackend::matmul_tensors( + lhs.into(), + CandleTensor::from_indexed_select(&rhs, 0, &indices), + ) + .materialize(); + + assert_eq!(actual.dims(), expected.dims()); + assert_eq!( + actual.flatten_all().unwrap().to_vec1::().unwrap(), + expected.flatten_all().unwrap().to_vec1::().unwrap() + ); +} diff --git a/catgrad/tests/test_interpreter_candle.rs b/catgrad/tests/test_interpreter_candle.rs index 37fd72eb..cc20e47d 100644 --- a/catgrad/tests/test_interpreter_candle.rs +++ b/catgrad/tests/test_interpreter_candle.rs @@ -1,8 +1,8 @@ #![cfg(feature = "candle-backend")] use catgrad::category::core::Shape; -use catgrad::interpreter::backend::Backend; use catgrad::interpreter::backend::candle::CandleBackend; +use catgrad::interpreter::backend::{Backend, BackendTensorOps}; use catgrad::interpreter::{TaggedTensor, TaggedTensorTuple, TaggedVec, Value}; // ============================================================================ @@ -20,7 +20,7 @@ fn test_candle_backend_basic_operations() { TaggedTensor::F32([arr]) => arr, _ => panic!("Expected F32"), }; - assert_eq!(zeros.0.shape().dims(), &[2, 3]); + assert_eq!(zeros.shape().0, &[2, 3]); // Test tensor creation from slice let data = vec![1.0f32, 2.0, 3.0, 4.0]; @@ -31,7 +31,7 @@ fn test_candle_backend_basic_operations() { TaggedTensor::F32([arr]) => arr, _ => panic!("Expected F32"), }; - assert_eq!(tensor.0.shape().dims(), &[2, 2]); + assert_eq!(tensor.shape().0, &[2, 2]); } #[test] @@ -61,7 +61,7 @@ fn test_candle_backend_arithmetic() { let result = backend.add(TaggedTensorTuple::F32([tensor1, tensor2])); match result { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 2]); + assert_eq!(arr.shape().0, &[2, 2]); let values = backend.to_vec(TaggedTensor::F32([arr])); if let TaggedVec::F32(v) = values { assert_eq!(v, vec![3.0, 5.0, 7.0, 9.0]); @@ -93,7 +93,7 @@ fn test_candle_backend_arithmetic() { let result = backend.mul(TaggedTensorTuple::F32([tensor3, tensor4])); match result { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 2]); + assert_eq!(arr.shape().0, &[2, 2]); let values = backend.to_vec(TaggedTensor::F32([arr])); if let TaggedVec::F32(v) = values { assert_eq!(v, vec![2.0, 6.0, 12.0, 20.0]); @@ -131,7 +131,7 @@ fn test_candle_backend_subtraction() { let result = backend.sub(TaggedTensorTuple::F32([tensor1, tensor2])); match result { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 2]); + assert_eq!(arr.shape().0, &[2, 2]); let values = backend.to_vec(TaggedTensor::F32([arr])); if let TaggedVec::F32(v) = values { assert_eq!(v, vec![9.0, 6.0, 3.0, 0.0]); @@ -164,7 +164,7 @@ fn test_candle_backend_subtraction() { let result = backend.sub(TaggedTensorTuple::U32([tensor3, tensor4])); match result { TaggedTensor::U32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 2]); + assert_eq!(arr.shape().0, &[2, 2]); let values = backend.to_vec(TaggedTensor::U32([arr])); if let TaggedVec::U32(v) = values { assert_eq!(v, vec![9, 6, 3, 0]); @@ -193,7 +193,7 @@ fn test_candle_backend_max() { let result = backend.max(TaggedTensor::F32([tensor])); match result { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 1]); + assert_eq!(arr.shape().0, &[2, 1]); let values = backend.to_vec(TaggedTensor::F32([arr])); if let TaggedVec::F32(v) = values { assert_eq!(v, vec![5.0, 8.0]); @@ -217,7 +217,7 @@ fn test_candle_backend_max() { let result = backend.max(TaggedTensor::U32([tensor_u32])); match result { TaggedTensor::U32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 1]); + assert_eq!(arr.shape().0, &[2, 1]); let values = backend.to_vec(TaggedTensor::U32([arr])); if let TaggedVec::U32(v) = values { assert_eq!(v, vec![5, 3]); @@ -247,7 +247,7 @@ fn test_candle_backend_argmax() { match result { TaggedTensor::U32([arr]) => { println!("argmax result: {:?}", arr); - assert_eq!(arr.0.shape().dims(), &[2, 1]); + assert_eq!(arr.shape().0, &[2, 1]); let values = backend.to_vec(TaggedTensor::U32([arr])); if let TaggedVec::U32(v) = values { assert_eq!(v, vec![1, 1]); @@ -272,7 +272,7 @@ fn test_candle_backend_argmax() { match result { TaggedTensor::U32([arr]) => { println!("argmax result: {:?}", arr); - assert_eq!(arr.0.shape().dims(), &[2, 1]); + assert_eq!(arr.shape().0, &[2, 1]); let values = backend.to_vec(TaggedTensor::U32([arr])); if let TaggedVec::U32(v) = values { assert_eq!(v, vec![1, 0]); @@ -301,7 +301,7 @@ fn test_candle_backend_sum() { let result = backend.sum(TaggedTensor::F32([tensor])); match result { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 1]); + assert_eq!(arr.shape().0, &[2, 1]); let values = backend.to_vec(TaggedTensor::F32([arr])); if let TaggedVec::F32(v) = values { assert_eq!(v, vec![6.0, 15.0]); @@ -325,7 +325,7 @@ fn test_candle_backend_sum() { let result = backend.sum(TaggedTensor::U32([tensor_u32])); match result { TaggedTensor::U32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 1]); + assert_eq!(arr.shape().0, &[2, 1]); let values = backend.to_vec(TaggedTensor::U32([arr])); if let TaggedVec::U32(v) = values { assert_eq!(v, vec![6, 15]); @@ -363,7 +363,7 @@ fn test_candle_backend_matmul() { let result = backend.matmul(TaggedTensorTuple::F32([tensor1, tensor2])); match result { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 2]); + assert_eq!(arr.shape().0, &[2, 2]); let values = backend.to_vec(TaggedTensor::F32([arr])); if let TaggedVec::F32(v) = values { assert_eq!(v, vec![22.0, 28.0, 49.0, 64.0]); @@ -391,7 +391,7 @@ fn test_candle_backend_reshape() { let reshaped = backend.reshape(TaggedTensor::F32([tensor]), Shape(vec![3, 2])); match reshaped { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[3, 2]); + assert_eq!(arr.shape().0, &[3, 2]); } _ => panic!("Expected F32 result"), } @@ -417,7 +417,7 @@ fn test_candle_backend_cast() { ); match casted { TaggedTensor::U32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 2]); + assert_eq!(arr.shape().0, &[2, 2]); let values = backend.to_vec(TaggedTensor::U32([arr])); if let TaggedVec::U32(v) = values { assert_eq!(v, vec![1, 2, 3, 4]); @@ -455,7 +455,7 @@ fn test_candle_backend_division() { let result = backend.div(TaggedTensorTuple::F32([tensor1, tensor2])); match result { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 2]); + assert_eq!(arr.shape().0, &[2, 2]); let values = backend.to_vec(TaggedTensor::F32([arr])); if let TaggedVec::F32(v) = values { assert_eq!(v, vec![3.0, 2.0, 2.0, 4.0]); @@ -488,7 +488,7 @@ fn test_candle_backend_division() { let result = backend.div(TaggedTensorTuple::U32([tensor3, tensor4])); match result { TaggedTensor::U32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 2]); + assert_eq!(arr.shape().0, &[2, 2]); let values = backend.to_vec(TaggedTensor::U32([arr])); if let TaggedVec::U32(v) = values { assert_eq!(v, vec![3, 2, 2, 4]); @@ -526,7 +526,7 @@ fn test_candle_backend_power() { let result = backend.pow(TaggedTensorTuple::F32([tensor1, tensor2])); match result { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 2]); + assert_eq!(arr.shape().0, &[2, 2]); let values = backend.to_vec(TaggedTensor::F32([arr])); if let TaggedVec::F32(v) = values { assert_eq!(v, vec![4.0, 9.0, 16.0, 25.0]); @@ -555,7 +555,7 @@ fn test_candle_backend_log() { let result = backend.log(TaggedTensor::F32([tensor])); match result { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 2]); + assert_eq!(arr.shape().0, &[2, 2]); let values = backend.to_vec(TaggedTensor::F32([arr])); if let TaggedVec::F32(v) = values { let expected = [1.0f32, 2.0, 3.0, 4.0] @@ -590,8 +590,13 @@ fn test_candle_backend_floor() { let result = backend.floor(TaggedTensor::F32([tensor])); match result { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 2]); - let result_data = arr.0.flatten_all().unwrap().to_vec1::().unwrap(); + assert_eq!(arr.shape().0, &[2, 2]); + let result_data = arr + .materialize() + .flatten_all() + .unwrap() + .to_vec1::() + .unwrap(); assert_eq!(result_data, vec![1.0, 2.0, -4.0, 4.0]); } _ => panic!("Expected F32 result"), @@ -615,7 +620,7 @@ fn test_candle_backend_negation() { let result = backend.neg(TaggedTensor::F32([tensor])); match result { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 2]); + assert_eq!(arr.shape().0, &[2, 2]); let values = backend.to_vec(TaggedTensor::F32([arr])); if let TaggedVec::F32(v) = values { assert_eq!(v, vec![-1.0, 2.0, -3.0, 4.0]); @@ -645,7 +650,7 @@ fn test_candle_backend_broadcast() { let broadcasted = backend.broadcast(TaggedTensor::F32([tensor]), Shape(vec![1, 2, 2])); match broadcasted { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[1, 2, 2]); + assert_eq!(arr.shape().0, &[1, 2, 2]); } _ => panic!("Expected F32 result"), } @@ -661,7 +666,7 @@ fn test_candle_backend_broadcast() { let broadcasted = backend.broadcast(TaggedTensor::F32([tensor]), Shape(vec![5, 2, 2])); match broadcasted { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[5, 2, 2]); + assert_eq!(arr.shape().0, &[5, 2, 2]); } _ => panic!("Expected F32 result"), } @@ -681,7 +686,7 @@ fn test_candle_backend_broadcast() { backend.broadcast(TaggedTensor::U32([tensor_u32]), Shape(vec![2, 1, 2, 2])); match broadcasted_u32 { TaggedTensor::U32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[2, 1, 2, 2]); + assert_eq!(arr.shape().0, &[2, 1, 2, 2]); } _ => panic!("Expected U32 result"), } @@ -707,7 +712,7 @@ fn test_candle_backend_broadcast_bad_shape() { let broadcasted = backend.broadcast(TaggedTensor::F32([tensor]), Shape(vec![2, 2, 2])); match broadcasted { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[1, 2, 2]); + assert_eq!(arr.shape().0, &[1, 2, 2]); } _ => panic!("Expected F32 result"), } @@ -867,7 +872,7 @@ fn test_candle_backend_empty_tensor() { TaggedTensor::F32([arr]) => arr, _ => panic!("Expected F32"), }; - assert_eq!(scalar.0.shape().dims(), &[] as &[usize]); + assert_eq!(scalar.shape().0, &[] as &[usize]); // Test zeros with single element let single_tagged = backend.zeros(Shape(vec![1]), Dtype::F32); @@ -875,7 +880,7 @@ fn test_candle_backend_empty_tensor() { TaggedTensor::F32([arr]) => arr, _ => panic!("Expected F32"), }; - assert_eq!(single.0.shape().dims(), &[1]); + assert_eq!(single.shape().0, &[1]); } #[test] @@ -914,12 +919,7 @@ fn test_candle_backend_single_element_operations() { ] { match result { TaggedTensor::F32([arr]) => { - assert_eq!( - arr.0.shape().dims(), - &[1], - "{} result should have shape [1]", - name - ); + assert_eq!(arr.shape().0, &[1], "{} result should have shape [1]", name); } _ => panic!("Expected F32 result for {}", name), } @@ -933,11 +933,7 @@ fn test_candle_backend_single_element_operations() { // Test negation (preserves shape) match neg_result { TaggedTensor::F32([arr]) => { - assert_eq!( - arr.0.shape().dims(), - &[1], - "neg result should have shape [1]" - ); + assert_eq!(arr.shape().0, &[1], "neg result should have shape [1]"); } _ => panic!("Expected F32 result for neg"), } @@ -945,22 +941,14 @@ fn test_candle_backend_single_element_operations() { // Test max and sum match max_result { TaggedTensor::F32([arr]) => { - assert_eq!( - arr.0.shape().dims(), - &[1], - "max result should have shape []" - ); + assert_eq!(arr.shape().0, &[1], "max result should have shape []"); } _ => panic!("Expected F32 result for max"), } match sum_result { TaggedTensor::F32([arr]) => { - assert_eq!( - arr.0.shape().dims(), - &[1], - "sum result should have shape []" - ); + assert_eq!(arr.shape().0, &[1], "sum result should have shape []"); } _ => panic!("Expected F32 result for sum"), } @@ -985,7 +973,7 @@ fn test_candle_backend_large_tensor() { let result = backend.add(TaggedTensorTuple::F32([tensor.clone(), tensor.clone()])); match result { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[10, 10]); + assert_eq!(arr.shape().0, &[10, 10]); } _ => panic!("Expected F32 result"), } @@ -996,14 +984,14 @@ fn test_candle_backend_large_tensor() { match sum_result { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[10, 1]); + assert_eq!(arr.shape().0, &[10, 1]); } _ => panic!("Expected F32 result for sum"), } match max_result { TaggedTensor::F32([arr]) => { - assert_eq!(arr.0.shape().dims(), &[10, 1]); + assert_eq!(arr.shape().0, &[10, 1]); } _ => panic!("Expected F32 result for max"), } @@ -1154,7 +1142,12 @@ fn test_candle_interpreter_exp() { assert!( allclose_f32( - &actual.0.flatten_all().unwrap().to_vec1().unwrap(), + &actual + .materialize() + .flatten_all() + .unwrap() + .to_vec1() + .unwrap(), &expected, 1e-5, 1e-8