@@ -762,6 +762,53 @@ where
762762    * strides = new_strides; 
763763} 
764764
765+ /// Remove axes with length one, except never removing the last axis. 
766+ pub ( crate )  fn  squeeze_into < D ,  E > ( dim :  & D ,  strides :  & D )  -> Result < ( E ,  E ) ,  ShapeError > 
767+ where 
768+     D :  Dimension , 
769+     E :  Dimension , 
770+ { 
771+     debug_assert_eq ! ( dim. ndim( ) ,  strides. ndim( ) ) ; 
772+ 
773+     // Count axes with dim == 1; we keep axes with d == 0 or d > 1 
774+     let  mut  ndim_new = 0 ; 
775+     for  & d in  dim. slice ( )  { 
776+         if  d != 1  {  ndim_new += 1 ;  } 
777+     } 
778+     let  mut  fill_ones = 0 ; 
779+     if  let  Some ( e_ndim)  = E :: NDIM  { 
780+         if  e_ndim < ndim_new { 
781+             return  Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ; 
782+         } 
783+         fill_ones = e_ndim - ndim_new; 
784+         ndim_new = e_ndim; 
785+     }  else  { 
786+         // dynamic-dimensional 
787+         // use minimum one dimension unless input has less than one dim 
788+         if  dim. ndim ( )  > 0  && ndim_new == 0  { 
789+             ndim_new = 1 ; 
790+             fill_ones = 1 ; 
791+         } 
792+     } 
793+ 
794+     let  mut  new_dim = E :: zeros ( ndim_new) ; 
795+     let  mut  new_strides = E :: zeros ( ndim_new) ; 
796+     let  mut  i = 0 ; 
797+     while  i < fill_ones { 
798+         new_dim[ i]  = 1 ; 
799+         new_strides[ i]  = 1 ; 
800+         i += 1 ; 
801+     } 
802+     for  ( & d,  & s)  in  izip ! ( dim. slice( ) ,  strides. slice( ) )  { 
803+         if  d != 1  { 
804+             new_dim[ i]  = d; 
805+             new_strides[ i]  = s; 
806+             i += 1 ; 
807+         } 
808+     } 
809+     Ok ( ( new_dim,  new_strides) ) 
810+ } 
811+ 
765812
766813/// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least 
767814/// stride 
@@ -1148,6 +1195,91 @@ mod test {
11481195        assert_eq ! ( s,  sans) ; 
11491196    } 
11501197
1198+     #[ test]  
1199+     #[ cfg( feature = "std" ) ]  
1200+     fn  test_squeeze_into ( )  { 
1201+         use  super :: squeeze_into; 
1202+ 
1203+         let  dyndim = Dim :: < & [ usize ] > ; 
1204+ 
1205+         // squeeze to ixdyn 
1206+         let  d = dyndim ( & [ 1 ,  2 ,  1 ,  1 ,  3 ,  1 ] ) ; 
1207+         let  s = dyndim ( & [ !0 ,  !0 ,  !0 ,  9 ,  10 ,  !0 ] ) ; 
1208+         let  dans = dyndim ( & [ 2 ,  3 ] ) ; 
1209+         let  sans = dyndim ( & [ !0 ,  10 ] ) ; 
1210+         let  ( d2,  s2)  = squeeze_into :: < _ ,  IxDyn > ( & d,  & s) . unwrap ( ) ; 
1211+         assert_eq ! ( d2,  dans) ; 
1212+         assert_eq ! ( s2,  sans) ; 
1213+ 
1214+         // squeeze to ixdyn does not go below 1D 
1215+         let  d = dyndim ( & [ 1 ,  1 ] ) ; 
1216+         let  s = dyndim ( & [ 3 ,  4 ] ) ; 
1217+         let  dans = dyndim ( & [ 1 ] ) ; 
1218+         let  sans = dyndim ( & [ 1 ] ) ; 
1219+         let  ( d2,  s2)  = squeeze_into :: < _ ,  IxDyn > ( & d,  & s) . unwrap ( ) ; 
1220+         assert_eq ! ( d2,  dans) ; 
1221+         assert_eq ! ( s2,  sans) ; 
1222+ 
1223+         let  d = Dim ( [ 1 ,  1 ] ) ; 
1224+         let  s = Dim ( [ 3 ,  4 ] ) ; 
1225+         let  dans = Dim ( [ 1 ] ) ; 
1226+         let  sans = Dim ( [ 1 ] ) ; 
1227+         let  ( d2,  s2)  = squeeze_into :: < _ ,  Ix1 > ( & d,  & s) . unwrap ( ) ; 
1228+         assert_eq ! ( d2,  dans) ; 
1229+         assert_eq ! ( s2,  sans) ; 
1230+ 
1231+         // squeeze to zero-dim 
1232+         let  ( d2,  s2)  = squeeze_into :: < _ ,  Ix0 > ( & d,  & s) . unwrap ( ) ; 
1233+         assert_eq ! ( d2,  Ix0 ( ) ) ; 
1234+         assert_eq ! ( s2,  Ix0 ( ) ) ; 
1235+ 
1236+         let  d = Dim ( [ 0 ,  1 ,  3 ,  4 ] ) ; 
1237+         let  s = Dim ( [ 2 ,  3 ,  4 ,  5 ] ) ; 
1238+         let  dans = Dim ( [ 0 ,  3 ,  4 ] ) ; 
1239+         let  sans = Dim ( [ 2 ,  4 ,  5 ] ) ; 
1240+         let  ( d2,  s2)  = squeeze_into :: < _ ,  Ix3 > ( & d,  & s) . unwrap ( ) ; 
1241+         assert_eq ! ( d2,  dans) ; 
1242+         assert_eq ! ( s2,  sans) ; 
1243+ 
1244+         // Pad with ones 
1245+         let  d = Dim ( [ 0 ,  1 ,  3 ,  1 ] ) ; 
1246+         let  s = Dim ( [ 2 ,  3 ,  4 ,  5 ] ) ; 
1247+         let  dans = Dim ( [ 1 ,  0 ,  3 ] ) ; 
1248+         let  sans = Dim ( [ 1 ,  2 ,  4 ] ) ; 
1249+         let  ( d2,  s2)  = squeeze_into :: < _ ,  Ix3 > ( & d,  & s) . unwrap ( ) ; 
1250+         assert_eq ! ( d2,  dans) ; 
1251+         assert_eq ! ( s2,  sans) ; 
1252+ 
1253+         // Try something that doesn't fit 
1254+         let  d = Dim ( [ 0 ,  1 ,  3 ,  1 ] ) ; 
1255+         let  s = Dim ( [ 2 ,  3 ,  4 ,  5 ] ) ; 
1256+         let  res = squeeze_into :: < _ ,  Ix1 > ( & d,  & s) ; 
1257+         assert ! ( res. is_err( ) ) ; 
1258+         let  res = squeeze_into :: < _ ,  Ix0 > ( & d,  & s) ; 
1259+         assert ! ( res. is_err( ) ) ; 
1260+ 
1261+         // Squeeze 0d to 0d 
1262+         let  d = Dim ( [ ] ) ; 
1263+         let  s = Dim ( [ ] ) ; 
1264+         let  res = squeeze_into :: < _ ,  Ix0 > ( & d,  & s) ; 
1265+         assert ! ( res. is_ok( ) ) ; 
1266+         // grow 0d to 2d 
1267+         let  dans = Dim ( [ 1 ,  1 ] ) ; 
1268+         let  sans = Dim ( [ 1 ,  1 ] ) ; 
1269+         let  ( d2,  s2)  = squeeze_into :: < _ ,  Ix2 > ( & d,  & s) . unwrap ( ) ; 
1270+         assert_eq ! ( d2,  dans) ; 
1271+         assert_eq ! ( s2,  sans) ; 
1272+ 
1273+         // Squeeze 0d to 0d dynamic 
1274+         let  d = dyndim ( & [ ] ) ; 
1275+         let  s = dyndim ( & [ ] ) ; 
1276+         let  ( d2,  s2)  = squeeze_into :: < _ ,  IxDyn > ( & d,  & s) . unwrap ( ) ; 
1277+         let  dans = d; 
1278+         let  sans = s; 
1279+         assert_eq ! ( d2,  dans) ; 
1280+         assert_eq ! ( s2,  sans) ; 
1281+     } 
1282+ 
11511283    #[ test]  
11521284    fn  test_merge_axes_from_the_back ( )  { 
11531285        let  dyndim = Dim :: < & [ usize ] > ; 
0 commit comments