@@ -16,22 +16,57 @@ mod windows;
1616use std:: iter:: FromIterator ;
1717use std:: marker:: PhantomData ;
1818use std:: ptr;
19+ use std:: slice:: { self , Iter as SliceIter , IterMut as SliceIterMut } ;
1920use alloc:: vec:: Vec ;
2021
22+ use crate :: imp_prelude:: * ;
2123use crate :: Ix1 ;
2224
23- use super :: { ArrayBase , ArrayView , ArrayViewMut , Axis , Data , NdProducer , RemoveAxis } ;
24- use super :: { Dimension , Ix , Ixs } ;
25+ use super :: { NdProducer , RemoveAxis } ;
2526
2627pub use self :: chunks:: { ExactChunks , ExactChunksIter , ExactChunksIterMut , ExactChunksMut } ;
2728pub use self :: lanes:: { Lanes , LanesMut } ;
2829pub use self :: windows:: Windows ;
2930
30- use std:: slice:: { self , Iter as SliceIter , IterMut as SliceIterMut } ;
31+ use crate :: dimension;
32+
33+ /// No traversal optmizations that would change element order or axis dimensions are permitted.
34+ ///
35+ /// This option is suitable for example for the indexed iterator.
36+ pub ( crate ) enum NoOptimization { }
37+
38+ /// Preserve element iteration order, but modify dimensions if profitable; for example we can
39+ /// change from shape [10, 1] to [1, 10], because that axis has len == 1, without consequence here.
40+ ///
41+ /// This option is suitable for example for the default .iter() iterator.
42+ pub ( crate ) enum PreserveOrder { }
43+
44+ /// Allow use of arbitrary element iteration order
45+ ///
46+ /// This option is suitable for example for an arbitrary order iterator.
47+ pub ( crate ) enum ArbitraryOrder { }
48+
49+ pub ( crate ) trait OrderOption {
50+ const ALLOW_REMOVE_REDUNDANT_AXES : bool = false ;
51+ const ALLOW_ARBITRARY_ORDER : bool = false ;
52+ }
53+
54+ impl OrderOption for NoOptimization { }
55+
56+ impl OrderOption for PreserveOrder {
57+ const ALLOW_REMOVE_REDUNDANT_AXES : bool = true ;
58+ }
59+
60+ impl OrderOption for ArbitraryOrder {
61+ const ALLOW_REMOVE_REDUNDANT_AXES : bool = true ;
62+ const ALLOW_ARBITRARY_ORDER : bool = true ;
63+ }
3164
3265/// Base for iterators over all axes.
3366///
3467/// Iterator element type is `*mut A`.
68+ ///
69+ /// `F` is for layout/iteration order flags
3570pub ( crate ) struct Baseiter < A , D > {
3671 ptr : * mut A ,
3772 dim : D ,
@@ -44,12 +79,43 @@ impl<A, D: Dimension> Baseiter<A, D> {
4479 /// to be correct to avoid performing an unsafe pointer offset while
4580 /// iterating.
4681 #[ inline]
47- pub unsafe fn new ( ptr : * mut A , len : D , stride : D ) -> Baseiter < A , D > {
82+ pub unsafe fn new ( ptr : * mut A , dim : D , strides : D ) -> Baseiter < A , D > {
83+ Self :: new_with_order :: < NoOptimization > ( ptr, dim, strides)
84+ }
85+ }
86+
87+ impl < A , D : Dimension > Baseiter < A , D > {
88+ /// Creating a Baseiter is unsafe because shape and stride parameters need
89+ /// to be correct to avoid performing an unsafe pointer offset while
90+ /// iterating.
91+ #[ inline]
92+ pub unsafe fn new_with_order < Flags : OrderOption > ( mut ptr : * mut A , mut dim : D , mut strides : D )
93+ -> Baseiter < A , D >
94+ {
95+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
96+ if Flags :: ALLOW_ARBITRARY_ORDER {
97+ // iterate in memory order; merge axes if possible
98+ // make all axes positive and put the pointer back to the first element in memory
99+ let offset = dimension:: offset_from_ptr_to_memory ( & dim, & strides) ;
100+ ptr = ptr. offset ( offset) ;
101+ for i in 0 ..strides. ndim ( ) {
102+ let s = strides. get_stride ( Axis ( i) ) ;
103+ if s < 0 {
104+ strides. set_stride ( Axis ( i) , -s) ;
105+ }
106+ }
107+ dimension:: sort_axes_to_standard ( & mut dim, & mut strides) ;
108+ }
109+ if Flags :: ALLOW_REMOVE_REDUNDANT_AXES {
110+ // preserve element order but shift dimensions
111+ dimension:: merge_axes_from_the_back ( & mut dim, & mut strides) ;
112+ dimension:: squeeze ( & mut dim, & mut strides) ;
113+ }
48114 Baseiter {
49115 ptr,
50- index : len . first_index ( ) ,
51- dim : len ,
52- strides : stride ,
116+ index : dim . first_index ( ) ,
117+ dim,
118+ strides,
53119 }
54120 }
55121}
@@ -1496,3 +1562,147 @@ where
14961562 debug_assert_eq ! ( size, result. len( ) ) ;
14971563 result
14981564}
1565+
1566+ #[ cfg( test) ]
1567+ #[ cfg( feature = "std" ) ]
1568+ mod tests {
1569+ use crate :: prelude:: * ;
1570+ use super :: Baseiter ;
1571+ use super :: { ArbitraryOrder , PreserveOrder , NoOptimization } ;
1572+ use itertools:: assert_equal;
1573+ use itertools:: Itertools ;
1574+
1575+ // 3-d axis swaps
1576+ fn swaps ( ) -> impl Iterator < Item =Vec < ( usize , usize ) > > {
1577+ vec ! [
1578+ vec![ ] ,
1579+ vec![ ( 0 , 1 ) ] ,
1580+ vec![ ( 0 , 2 ) ] ,
1581+ vec![ ( 1 , 2 ) ] ,
1582+ vec![ ( 0 , 1 ) , ( 1 , 2 ) ] ,
1583+ vec![ ( 0 , 1 ) , ( 0 , 2 ) ] ,
1584+ ] . into_iter ( )
1585+ }
1586+
1587+ // 3-d axis inverts
1588+ fn inverts ( ) -> impl Iterator < Item =Vec < Axis > > {
1589+ vec ! [
1590+ vec![ ] ,
1591+ vec![ Axis ( 0 ) ] ,
1592+ vec![ Axis ( 1 ) ] ,
1593+ vec![ Axis ( 2 ) ] ,
1594+ vec![ Axis ( 0 ) , Axis ( 1 ) ] ,
1595+ vec![ Axis ( 0 ) , Axis ( 2 ) ] ,
1596+ vec![ Axis ( 1 ) , Axis ( 2 ) ] ,
1597+ vec![ Axis ( 0 ) , Axis ( 1 ) , Axis ( 2 ) ] ,
1598+ ] . into_iter ( )
1599+ }
1600+
1601+ #[ test]
1602+ fn test_arbitrary_order ( ) {
1603+ for swap in swaps ( ) {
1604+ for invert in inverts ( ) {
1605+ for & slice in & [ false , true ] {
1606+ // pattern is 0, 1; 4, 5; 8, 9; etc..
1607+ let mut a = Array :: from_iter ( 0 ..24 ) . into_shape ( ( 3 , 4 , 2 ) ) . unwrap ( ) ;
1608+ if slice {
1609+ a. slice_collapse ( s ! [ .., ..; 2 , ..] ) ;
1610+ }
1611+ for & ( i, j) in & swap {
1612+ a. swap_axes ( i, j) ;
1613+ }
1614+ for & i in & invert {
1615+ a. invert_axis ( i) ;
1616+ }
1617+ unsafe {
1618+ // Should have in-memory order for arbitrary order
1619+ let iter = Baseiter :: new_with_order :: < ArbitraryOrder > ( a. as_mut_ptr ( ) ,
1620+ a. dim , a. strides ) ;
1621+ if !slice {
1622+ assert_equal ( iter. map ( |ptr| * ptr) , 0 ..a. len ( ) ) ;
1623+ } else {
1624+ assert_eq ! ( iter. map( |ptr| * ptr) . collect_vec( ) ,
1625+ ( 0 ..a. len( ) * 2 ) . filter( |& x| ( x / 2 ) % 2 == 0 ) . collect_vec( ) ) ;
1626+ }
1627+ }
1628+ }
1629+ }
1630+ }
1631+ }
1632+
1633+ #[ test]
1634+ fn test_logical_order ( ) {
1635+ for swap in swaps ( ) {
1636+ for invert in inverts ( ) {
1637+ for & slice in & [ false , true ] {
1638+ let mut a = Array :: from_iter ( 0 ..24 ) . into_shape ( ( 3 , 4 , 2 ) ) . unwrap ( ) ;
1639+ for & ( i, j) in & swap {
1640+ a. swap_axes ( i, j) ;
1641+ }
1642+ for & i in & invert {
1643+ a. invert_axis ( i) ;
1644+ }
1645+ if slice {
1646+ a. slice_collapse ( s ! [ .., ..; 2 , ..] ) ;
1647+ }
1648+
1649+ unsafe {
1650+ let mut iter = Baseiter :: new_with_order :: < NoOptimization > ( a. as_mut_ptr ( ) ,
1651+ a. dim , a. strides ) ;
1652+ let mut index = Dim ( [ 0 , 0 , 0 ] ) ;
1653+ let mut elts = 0 ;
1654+ while let Some ( elt) = iter. next ( ) {
1655+ assert_eq ! ( * elt, a[ index] ) ;
1656+ if let Some ( index_) = a. raw_dim ( ) . next_for ( index) {
1657+ index = index_;
1658+ }
1659+ elts += 1 ;
1660+ }
1661+ assert_eq ! ( elts, a. len( ) ) ;
1662+ }
1663+ }
1664+ }
1665+ }
1666+ }
1667+
1668+ #[ test]
1669+ fn test_preserve_order ( ) {
1670+ for swap in swaps ( ) {
1671+ for invert in inverts ( ) {
1672+ for & slice in & [ false , true ] {
1673+ let mut a = Array :: from_iter ( 0 ..20 ) . into_shape ( ( 2 , 10 , 1 ) ) . unwrap ( ) ;
1674+ for & ( i, j) in & swap {
1675+ a. swap_axes ( i, j) ;
1676+ }
1677+ for & i in & invert {
1678+ a. invert_axis ( i) ;
1679+ }
1680+ if slice {
1681+ a. slice_collapse ( s ! [ .., ..; 2 , ..] ) ;
1682+ }
1683+
1684+ unsafe {
1685+ let mut iter = Baseiter :: new_with_order :: < PreserveOrder > (
1686+ a. as_mut_ptr ( ) , a. dim , a. strides ) ;
1687+
1688+ // check that axes have been merged (when it's easy to check)
1689+ if a. shape ( ) == & [ 2 , 10 , 1 ] && invert. is_empty ( ) {
1690+ assert_eq ! ( iter. dim, Dim ( [ 1 , 1 , 20 ] ) ) ;
1691+ }
1692+
1693+ let mut index = Dim ( [ 0 , 0 , 0 ] ) ;
1694+ let mut elts = 0 ;
1695+ while let Some ( elt) = iter. next ( ) {
1696+ assert_eq ! ( * elt, a[ index] ) ;
1697+ if let Some ( index_) = a. raw_dim ( ) . next_for ( index) {
1698+ index = index_;
1699+ }
1700+ elts += 1 ;
1701+ }
1702+ assert_eq ! ( elts, a. len( ) ) ;
1703+ }
1704+ }
1705+ }
1706+ }
1707+ }
1708+ }
0 commit comments