@@ -19,22 +19,59 @@ use alloc::vec::Vec;
1919use std:: iter:: FromIterator ;
2020use std:: marker:: PhantomData ;
2121use std:: ptr;
22+ use std:: slice:: { self , Iter as SliceIter , IterMut as SliceIterMut } ;
2223
23- use crate :: Ix1 ;
24+ use crate :: imp_prelude :: * ;
2425
25- use super :: { ArrayBase , ArrayView , ArrayViewMut , Axis , Data , NdProducer , RemoveAxis } ;
26- use super :: { Dimension , Ix , Ixs } ;
26+ use super :: NdProducer ;
2727
2828pub use self :: chunks:: { ExactChunks , ExactChunksIter , ExactChunksIterMut , ExactChunksMut } ;
2929pub use self :: into_iter:: IntoIter ;
3030pub use self :: lanes:: { Lanes , LanesMut } ;
3131pub use self :: windows:: Windows ;
3232
33- use std:: slice:: { self , Iter as SliceIter , IterMut as SliceIterMut } ;
33+ use crate :: dimension;
34+
35+ /// No traversal optmizations that would change element order or axis dimensions are permitted.
36+ ///
37+ /// This option is suitable for example for the indexed iterator.
38+ pub ( crate ) enum NoOptimization { }
39+
40+ /// Preserve element iteration order, but modify dimensions if profitable; for example we can
41+ /// change from shape [10, 1] to [1, 10], because that axis has len == 1, without consequence here.
42+ ///
43+ /// This option is suitable for example for the default .iter() iterator.
44+ pub ( crate ) enum PreserveOrder { }
45+
46+ /// Allow use of arbitrary element iteration order
47+ ///
48+ /// This option is suitable for example for an arbitrary order iterator.
49+ pub ( crate ) enum ArbitraryOrder { }
50+
51+ pub ( crate ) trait OrderOption
52+ {
53+ const ALLOW_REMOVE_REDUNDANT_AXES : bool = false ;
54+ const ALLOW_ARBITRARY_ORDER : bool = false ;
55+ }
56+
57+ impl OrderOption for NoOptimization { }
58+
59+ impl OrderOption for PreserveOrder
60+ {
61+ const ALLOW_REMOVE_REDUNDANT_AXES : bool = true ;
62+ }
63+
64+ impl OrderOption for ArbitraryOrder
65+ {
66+ const ALLOW_REMOVE_REDUNDANT_AXES : bool = true ;
67+ const ALLOW_ARBITRARY_ORDER : bool = true ;
68+ }
3469
3570/// Base for iterators over all axes.
3671///
3772/// Iterator element type is `*mut A`.
73+ ///
74+ /// `F` is for layout/iteration order flags
3875#[ derive( Debug ) ]
3976pub ( crate ) struct Baseiter < A , D >
4077{
@@ -50,13 +87,46 @@ impl<A, D: Dimension> Baseiter<A, D>
5087 /// to be correct to avoid performing an unsafe pointer offset while
5188 /// iterating.
5289 #[ inline]
53- pub unsafe fn new ( ptr : * mut A , len : D , stride : D ) -> Baseiter < A , D >
90+ pub unsafe fn new ( ptr : * mut A , dim : D , strides : D ) -> Baseiter < A , D >
5491 {
92+ Self :: new_with_order :: < NoOptimization > ( ptr, dim, strides)
93+ }
94+ }
95+
96+ impl < A , D : Dimension > Baseiter < A , D >
97+ {
98+ /// Creating a Baseiter is unsafe because shape and stride parameters need
99+ /// to be correct to avoid performing an unsafe pointer offset while
100+ /// iterating.
101+ #[ inline]
102+ pub unsafe fn new_with_order < Flags : OrderOption > ( mut ptr : * mut A , mut dim : D , mut strides : D ) -> Baseiter < A , D >
103+ {
104+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
105+ if Flags :: ALLOW_ARBITRARY_ORDER {
106+ // iterate in memory order; merge axes if possible
107+ // make all axes positive and put the pointer back to the first element in memory
108+ let offset = dimension:: offset_from_low_addr_ptr_to_logical_ptr ( & dim, & strides) ;
109+ ptr = ptr. sub ( offset) ;
110+ for i in 0 ..strides. ndim ( ) {
111+ let s = strides. get_stride ( Axis ( i) ) ;
112+ if s < 0 {
113+ strides. set_stride ( Axis ( i) , -s) ;
114+ }
115+ }
116+ dimension:: sort_axes_to_standard ( & mut dim, & mut strides) ;
117+ }
118+
119+ if Flags :: ALLOW_REMOVE_REDUNDANT_AXES {
120+ // preserve element order but shift dimensions
121+ dimension:: merge_axes_from_the_back ( & mut dim, & mut strides) ;
122+ dimension:: squeeze ( & mut dim, & mut strides) ;
123+ }
124+
55125 Baseiter {
56126 ptr,
57- index : len . first_index ( ) ,
58- dim : len ,
59- strides : stride ,
127+ index : dim . first_index ( ) ,
128+ dim,
129+ strides,
60130 }
61131 }
62132}
@@ -1585,3 +1655,152 @@ where
15851655 debug_assert_eq ! ( size, result. len( ) ) ;
15861656 result
15871657}
1658+
1659+ #[ cfg( test) ]
1660+ #[ cfg( feature = "std" ) ]
1661+ mod tests
1662+ {
1663+ use super :: Baseiter ;
1664+ use super :: { ArbitraryOrder , NoOptimization , PreserveOrder } ;
1665+ use crate :: prelude:: * ;
1666+ use itertools:: assert_equal;
1667+ use itertools:: Itertools ;
1668+
1669+ // 3-d axis swaps
1670+ fn swaps ( ) -> impl Iterator < Item = Vec < ( usize , usize ) > >
1671+ {
1672+ vec ! [
1673+ vec![ ] ,
1674+ vec![ ( 0 , 1 ) ] ,
1675+ vec![ ( 0 , 2 ) ] ,
1676+ vec![ ( 1 , 2 ) ] ,
1677+ vec![ ( 0 , 1 ) , ( 1 , 2 ) ] ,
1678+ vec![ ( 0 , 1 ) , ( 0 , 2 ) ] ,
1679+ ]
1680+ . into_iter ( )
1681+ }
1682+
1683+ // 3-d axis inverts
1684+ fn inverts ( ) -> impl Iterator < Item = Vec < Axis > >
1685+ {
1686+ vec ! [
1687+ vec![ ] ,
1688+ vec![ Axis ( 0 ) ] ,
1689+ vec![ Axis ( 1 ) ] ,
1690+ vec![ Axis ( 2 ) ] ,
1691+ vec![ Axis ( 0 ) , Axis ( 1 ) ] ,
1692+ vec![ Axis ( 0 ) , Axis ( 2 ) ] ,
1693+ vec![ Axis ( 1 ) , Axis ( 2 ) ] ,
1694+ vec![ Axis ( 0 ) , Axis ( 1 ) , Axis ( 2 ) ] ,
1695+ ]
1696+ . into_iter ( )
1697+ }
1698+
1699+ #[ test]
1700+ fn test_arbitrary_order ( )
1701+ {
1702+ for swap in swaps ( ) {
1703+ for invert in inverts ( ) {
1704+ for & slice in & [ false , true ] {
1705+ // pattern is 0, 1; 4, 5; 8, 9; etc..
1706+ let mut a = Array :: from_iter ( 0 ..24 ) . into_shape ( ( 3 , 4 , 2 ) ) . unwrap ( ) ;
1707+ if slice {
1708+ a. slice_collapse ( s ! [ .., ..; 2 , ..] ) ;
1709+ }
1710+ for & ( i, j) in & swap {
1711+ a. swap_axes ( i, j) ;
1712+ }
1713+ for & i in & invert {
1714+ a. invert_axis ( i) ;
1715+ }
1716+ unsafe {
1717+ // Should have in-memory order for arbitrary order
1718+ let iter = Baseiter :: new_with_order :: < ArbitraryOrder > ( a. as_mut_ptr ( ) , a. dim , a. strides ) ;
1719+ if !slice {
1720+ assert_equal ( iter. map ( |ptr| * ptr) , 0 ..a. len ( ) ) ;
1721+ } else {
1722+ assert_eq ! ( iter. map( |ptr| * ptr) . collect_vec( ) ,
1723+ ( 0 ..a. len( ) * 2 ) . filter( |& x| ( x / 2 ) % 2 == 0 ) . collect_vec( ) ) ;
1724+ }
1725+ }
1726+ }
1727+ }
1728+ }
1729+ }
1730+
1731+ #[ test]
1732+ fn test_logical_order ( )
1733+ {
1734+ for swap in swaps ( ) {
1735+ for invert in inverts ( ) {
1736+ for & slice in & [ false , true ] {
1737+ let mut a = Array :: from_iter ( 0 ..24 ) . into_shape ( ( 3 , 4 , 2 ) ) . unwrap ( ) ;
1738+ for & ( i, j) in & swap {
1739+ a. swap_axes ( i, j) ;
1740+ }
1741+ for & i in & invert {
1742+ a. invert_axis ( i) ;
1743+ }
1744+ if slice {
1745+ a. slice_collapse ( s ! [ .., ..; 2 , ..] ) ;
1746+ }
1747+
1748+ unsafe {
1749+ let mut iter = Baseiter :: new_with_order :: < NoOptimization > ( a. as_mut_ptr ( ) , a. dim , a. strides ) ;
1750+ let mut index = Dim ( [ 0 , 0 , 0 ] ) ;
1751+ let mut elts = 0 ;
1752+ while let Some ( elt) = iter. next ( ) {
1753+ assert_eq ! ( * elt, a[ index] ) ;
1754+ if let Some ( index_) = a. raw_dim ( ) . next_for ( index) {
1755+ index = index_;
1756+ }
1757+ elts += 1 ;
1758+ }
1759+ assert_eq ! ( elts, a. len( ) ) ;
1760+ }
1761+ }
1762+ }
1763+ }
1764+ }
1765+
1766+ #[ test]
1767+ fn test_preserve_order ( )
1768+ {
1769+ for swap in swaps ( ) {
1770+ for invert in inverts ( ) {
1771+ for & slice in & [ false , true ] {
1772+ let mut a = Array :: from_iter ( 0 ..20 ) . into_shape ( ( 2 , 10 , 1 ) ) . unwrap ( ) ;
1773+ for & ( i, j) in & swap {
1774+ a. swap_axes ( i, j) ;
1775+ }
1776+ for & i in & invert {
1777+ a. invert_axis ( i) ;
1778+ }
1779+ if slice {
1780+ a. slice_collapse ( s ! [ .., ..; 2 , ..] ) ;
1781+ }
1782+
1783+ unsafe {
1784+ let mut iter = Baseiter :: new_with_order :: < PreserveOrder > ( a. as_mut_ptr ( ) , a. dim , a. strides ) ;
1785+
1786+ // check that axes have been merged (when it's easy to check)
1787+ if a. shape ( ) == & [ 2 , 10 , 1 ] && invert. is_empty ( ) {
1788+ assert_eq ! ( iter. dim, Dim ( [ 1 , 1 , 20 ] ) ) ;
1789+ }
1790+
1791+ let mut index = Dim ( [ 0 , 0 , 0 ] ) ;
1792+ let mut elts = 0 ;
1793+ while let Some ( elt) = iter. next ( ) {
1794+ assert_eq ! ( * elt, a[ index] ) ;
1795+ if let Some ( index_) = a. raw_dim ( ) . next_for ( index) {
1796+ index = index_;
1797+ }
1798+ elts += 1 ;
1799+ }
1800+ assert_eq ! ( elts, a. len( ) ) ;
1801+ }
1802+ }
1803+ }
1804+ }
1805+ }
1806+ }
0 commit comments