@@ -785,6 +785,41 @@ where D: Dimension
785785 }
786786}
787787
788+ /// Remove axes with length one, except never removing the last axis.
789+ pub ( crate ) fn squeeze < D > ( dim : & mut D , strides : & mut D )
790+ where D : Dimension
791+ {
792+ if let Some ( _) = D :: NDIM {
793+ return ;
794+ }
795+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
796+
797+ // Count axes with dim == 1; we keep axes with d == 0 or d > 1
798+ let mut ndim_new = 0 ;
799+ for & d in dim. slice ( ) {
800+ if d != 1 {
801+ ndim_new += 1 ;
802+ }
803+ }
804+ ndim_new = Ord :: max ( 1 , ndim_new) ;
805+ let mut new_dim = D :: zeros ( ndim_new) ;
806+ let mut new_strides = D :: zeros ( ndim_new) ;
807+ let mut i = 0 ;
808+ for ( & d, & s) in izip ! ( dim. slice( ) , strides. slice( ) ) {
809+ if d != 1 {
810+ new_dim[ i] = d;
811+ new_strides[ i] = s;
812+ i += 1 ;
813+ }
814+ }
815+ if i == 0 {
816+ new_dim[ i] = 1 ;
817+ new_strides[ i] = 1 ;
818+ }
819+ * dim = new_dim;
820+ * strides = new_strides;
821+ }
822+
788823#[ cfg( test) ]
789824mod test
790825{
@@ -797,6 +832,7 @@ mod test
797832 slice_min_max,
798833 slices_intersect,
799834 solve_linear_diophantine_eq,
835+ squeeze,
800836 IntoDimension ,
801837 } ;
802838 use crate :: error:: { from_kind, ErrorKind } ;
@@ -1146,4 +1182,35 @@ mod test
11461182 s![ .., 3 ..; 6 , NewAxis ]
11471183 ) ) ;
11481184 }
1185+
1186+ #[ test]
1187+ #[ cfg( feature = "std" ) ]
1188+ fn test_squeeze ( )
1189+ {
1190+ let dyndim = Dim :: < & [ usize ] > ;
1191+
1192+ let mut d = dyndim ( & [ 1 , 2 , 1 , 1 , 3 , 1 ] ) ;
1193+ let mut s = dyndim ( & [ !0 , !0 , !0 , 9 , 10 , !0 ] ) ;
1194+ let dans = dyndim ( & [ 2 , 3 ] ) ;
1195+ let sans = dyndim ( & [ !0 , 10 ] ) ;
1196+ squeeze ( & mut d, & mut s) ;
1197+ assert_eq ! ( d, dans) ;
1198+ assert_eq ! ( s, sans) ;
1199+
1200+ let mut d = dyndim ( & [ 1 , 1 ] ) ;
1201+ let mut s = dyndim ( & [ 3 , 4 ] ) ;
1202+ let dans = dyndim ( & [ 1 ] ) ;
1203+ let sans = dyndim ( & [ 1 ] ) ;
1204+ squeeze ( & mut d, & mut s) ;
1205+ assert_eq ! ( d, dans) ;
1206+ assert_eq ! ( s, sans) ;
1207+
1208+ let mut d = dyndim ( & [ 0 , 1 , 3 , 4 ] ) ;
1209+ let mut s = dyndim ( & [ 2 , 3 , 4 , 5 ] ) ;
1210+ let dans = dyndim ( & [ 0 , 3 , 4 ] ) ;
1211+ let sans = dyndim ( & [ 2 , 4 , 5 ] ) ;
1212+ squeeze ( & mut d, & mut s) ;
1213+ assert_eq ! ( d, dans) ;
1214+ assert_eq ! ( s, sans) ;
1215+ }
11491216}
0 commit comments