@@ -139,6 +139,25 @@ impl Mpi {
139139 }
140140 }
141141
142+ /// Checks if an [`Mpi`] is less than the other in constant time.
143+ ///
144+ /// Will return [`Error::MpiBadInputData`] if the allocated length of the two input [`Mpi`]s is not the same.
145+ pub fn less_than_const_time ( & self , other : & Mpi ) -> Result < bool > {
146+ mpi_inner_less_than_const_time ( & self . inner , & other. inner )
147+ }
148+
149+ /// Compares an [`Mpi`] with the other in constant time.
150+ ///
151+ /// Will return [`Error::MpiBadInputData`] if the allocated length of the two input [`Mpi`]s is not the same.
152+ pub fn cmp_const_time ( & self , other : & Mpi ) -> Result < Ordering > {
153+ mpi_inner_cmp_const_time ( & self . inner , & other. inner )
154+ }
155+
156+ /// Checks equalness with the other in constant time.
157+ pub fn eq_const_time ( & self , other : & Mpi ) -> Result < bool > {
158+ mpi_inner_eq_const_time ( & self . inner , & other. inner )
159+ }
160+
142161 pub fn as_u32 ( & self ) -> Result < u32 > {
143162 if self . bit_length ( ) ? > 32 {
144163 // Not exactly correct but close enough
@@ -409,6 +428,35 @@ impl Mpi {
409428 }
410429}
411430
431+ pub ( super ) fn mpi_inner_eq_const_time ( x : & mpi , y : & mpi ) -> core:: prelude:: v1:: Result < bool , Error > {
432+ match mpi_inner_cmp_const_time ( x, y) {
433+ Ok ( order) => Ok ( order == Ordering :: Equal ) ,
434+ Err ( Error :: MpiBadInputData ) => Ok ( false ) ,
435+ Err ( e) => Err ( e) ,
436+ }
437+ }
438+
439+ fn mpi_inner_cmp_const_time ( x : & mpi , y : & mpi ) -> Result < Ordering > {
440+ let less = mpi_inner_less_than_const_time ( x, y) ;
441+ let more = mpi_inner_less_than_const_time ( y, x) ;
442+ match ( less, more) {
443+ ( Ok ( true ) , Ok ( false ) ) => Ok ( Ordering :: Less ) ,
444+ ( Ok ( false ) , Ok ( true ) ) => Ok ( Ordering :: Greater ) ,
445+ ( Ok ( false ) , Ok ( false ) ) => Ok ( Ordering :: Equal ) ,
446+ ( Ok ( true ) , Ok ( true ) ) => unreachable ! ( ) ,
447+ ( Err ( e) , _) => Err ( e) ,
448+ ( Ok ( _) , Err ( e) ) => Err ( e) ,
449+ }
450+ }
451+
452+ fn mpi_inner_less_than_const_time ( x : & mpi , y : & mpi ) -> Result < bool > {
453+ let mut r = 0 ;
454+ unsafe {
455+ mpi_lt_mpi_ct ( x, y, & mut r) . into_result ( ) ?;
456+ } ;
457+ Ok ( r == 1 )
458+ }
459+
412460impl Ord for Mpi {
413461 fn cmp ( & self , other : & Mpi ) -> Ordering {
414462 let r = unsafe { mpi_cmp_mpi ( & self . inner , & other. inner ) } ;
@@ -709,3 +757,53 @@ impl ShrAssign<usize> for Mpi {
709757// mbedtls_mpi_sub_abs
710758// mbedtls_mpi_mod_int
711759// mbedtls_mpi_gcd
760+
761+ #[ cfg( test) ]
762+ mod tests {
763+ use super :: * ;
764+
765+ #[ test]
766+ fn test_less_than_const_time ( ) {
767+ let mpi1 = Mpi :: new ( 10 ) . unwrap ( ) ;
768+ let mpi2 = Mpi :: new ( 20 ) . unwrap ( ) ;
769+
770+ assert_eq ! ( mpi1. less_than_const_time( & mpi2) , Ok ( true ) ) ;
771+
772+ assert_eq ! ( mpi1. less_than_const_time( & mpi1) , Ok ( false ) ) ;
773+
774+ assert_eq ! ( mpi2. less_than_const_time( & mpi1) , Ok ( false ) ) ;
775+
776+ // Check: function returns `Error::MpiBadInputData` if the allocated length of the two input Mpis is not the same.
777+ let mpi3 = Mpi :: from_binary ( & [
778+ 0xdd , 0xdd , 0xdd , 0xdd , 0xdd , 0xdd , 0xdd , 0xdd , 0xdd , 0xdd , 0xdd , 0xdd , 0xdd , 0xdd ,
779+ ] )
780+ . unwrap ( ) ;
781+ assert_eq ! ( mpi3. less_than_const_time( & mpi3) , Ok ( false ) ) ;
782+ assert_eq ! ( mpi2. less_than_const_time( & mpi3) , Err ( Error :: MpiBadInputData ) ) ;
783+ }
784+
785+ #[ test]
786+ fn test_cmp_const_time ( ) {
787+ let mpi1 = Mpi :: new ( 10 ) . unwrap ( ) ;
788+ let mpi2 = Mpi :: new ( 20 ) . unwrap ( ) ;
789+
790+ assert_eq ! ( mpi1. cmp_const_time( & mpi2) , Ok ( Ordering :: Less ) ) ;
791+
792+ let mpi3 = Mpi :: new ( 10 ) . unwrap ( ) ;
793+ assert_eq ! ( mpi1. cmp_const_time( & mpi3) , Ok ( Ordering :: Equal ) ) ;
794+
795+ let mpi4 = Mpi :: new ( 5 ) . unwrap ( ) ;
796+ assert_eq ! ( mpi1. cmp_const_time( & mpi4) , Ok ( Ordering :: Greater ) ) ;
797+ }
798+
799+ #[ test]
800+ fn test_eq_const_time ( ) {
801+ let mpi1 = Mpi :: new ( 10 ) . unwrap ( ) ;
802+ let mpi2 = Mpi :: new ( 10 ) . unwrap ( ) ;
803+
804+ assert_eq ! ( mpi1. eq_const_time( & mpi2) , Ok ( true ) ) ;
805+
806+ let mpi3 = Mpi :: new ( 20 ) . unwrap ( ) ;
807+ assert_eq ! ( mpi1. eq_const_time( & mpi3) , Ok ( false ) ) ;
808+ }
809+ }
0 commit comments