diff --git a/crates/piston-core/src/indexer.rs b/crates/piston-core/src/indexer.rs new file mode 100644 index 00000000..e845ab14 --- /dev/null +++ b/crates/piston-core/src/indexer.rs @@ -0,0 +1,142 @@ +// Adapted from Candle: https://github.com/huggingface/candle/blob/main/candle-core/src/indexer.rs +use crate::OpTensor; +use anyhow::Error; +use std::ops::{ + Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive, +}; + +impl OpTensor { + fn index(&self, indexers: &[TensorIndexer]) -> Result { + let mut x = self.clone(); + let dims = self.shape().as_slice(); + let mut current_dim = 0; + for (i, indexer) in indexers.iter().enumerate() { + x = match indexer { + TensorIndexer::Select(n) => x.narrow(current_dim, *n, 1)?.squeeze(current_dim)?, + TensorIndexer::Narrow(left_bound, right_bound) => { + let start = match left_bound { + Bound::Included(n) => *n, + Bound::Excluded(n) => *n + 1, + Bound::Unbounded => 0, + }; + let stop = match right_bound { + Bound::Included(n) => *n + 1, + Bound::Excluded(n) => *n, + Bound::Unbounded => dims[i], + }; + let out = x.narrow(current_dim, start, stop.saturating_sub(start))?; + current_dim += 1; + out + } + TensorIndexer::IndexSelect(indexes) => { + if indexes.rank() != 1 { + anyhow::bail!("multi-dimensional tensor indexing is not supported") + } + if indexes.device() != x.device() { + anyhow::bail!("indexing device mismatch: index tensor is on {:?} but input tensor is on {:?}", indexes.device(), x.device()) + } + let out = x.index_select(indexes.clone(), current_dim)?; + current_dim += 1; + out + } + TensorIndexer::Err(e) => anyhow::bail!("indexing error {e:?}"), + }; + } + Ok(x) + } +} + +#[derive(Debug)] +/// Generic structure used to index a slice of the tensor +pub enum TensorIndexer { + /// This selects the elements for which an index has some specific value. + Select(usize), + /// This is a regular slice, purely indexing a chunk of the tensor + Narrow(Bound, Bound), + /// Indexing via a 1d tensor + IndexSelect(OpTensor), + Err(Error), +} + +impl From for TensorIndexer { + fn from(index: usize) -> Self { + TensorIndexer::Select(index) + } +} + +impl From<&OpTensor> for TensorIndexer { + fn from(tensor: &OpTensor) -> Self { + TensorIndexer::IndexSelect(tensor.clone()) + } +} + +trait RB: RangeBounds {} +impl RB for Range {} +impl RB for RangeFrom {} +impl RB for RangeFull {} +impl RB for RangeInclusive {} +impl RB for RangeTo {} +impl RB for RangeToInclusive {} + +impl From for TensorIndexer { + fn from(range: T) -> Self { + use std::ops::Bound::*; + let start = match range.start_bound() { + Included(idx) => Included(*idx), + Excluded(idx) => Excluded(*idx), + Unbounded => Unbounded, + }; + let end = match range.end_bound() { + Included(idx) => Included(*idx), + Excluded(idx) => Excluded(*idx), + Unbounded => Unbounded, + }; + TensorIndexer::Narrow(start, end) + } +} + +/// Trait used to implement multiple signatures for ease of use of the slicing +/// of a tensor +pub trait IndexOp { + /// Returns a slicing iterator which are the chunks of data necessary to + /// reconstruct the desired tensor. + fn i(&self, index: T) -> Result; +} + +impl IndexOp for OpTensor +where + T: Into, +{ + fn i(&self, index: T) -> Result { + self.index(&[index.into()]) + } +} + +impl IndexOp<(A,)> for OpTensor +where + A: Into, +{ + fn i(&self, (a,): (A,)) -> Result { + self.index(&[a.into()]) + } +} + +macro_rules! index_op_tuple { + ($($t:ident),+) => { + #[allow(non_snake_case)] + impl<$($t),*> IndexOp<($($t,)*)> for OpTensor + where + $($t: Into,)* + { + fn i(&self, ($($t,)*): ($($t,)*)) -> Result { + self.index(&[$($t.into(),)*]) + } + } + }; +} + +index_op_tuple!(A, B, C); +index_op_tuple!(A, B, C, D); +index_op_tuple!(A, B, C, D, E); +index_op_tuple!(A, B, C, D, E, F); +index_op_tuple!(A, B, C, D, E, F, G); diff --git a/crates/piston-core/src/lib.rs b/crates/piston-core/src/lib.rs index cccba122..75b1793a 100644 --- a/crates/piston-core/src/lib.rs +++ b/crates/piston-core/src/lib.rs @@ -7,6 +7,7 @@ mod dtype; mod enforcer; mod executable; mod gpu; +mod indexer; mod ndarray_ext; mod op; mod ops; @@ -28,6 +29,7 @@ pub use dtype::*; pub use enforcer::*; pub use executable::*; pub use gpu::*; +pub use indexer::*; pub use ndarray_ext::*; pub use op::*; pub use ops::*;