@@ -12,6 +12,7 @@ use crate::{
1212 linalg:: array_ext:: { RemoveAxisExt , ShapeExt } ,
1313} ;
1414
15+ /// Represents the output of an LU decomposition acting on a matrix.
1516#[ derive( Debug ) ]
1617pub struct LU < A , S >
1718where
@@ -22,13 +23,22 @@ where
2223 pub ( self ) pivots : Array1 < usize > ,
2324}
2425
26+ /// Represents types that represent matrices that are decomposable into lower-
27+ /// and upper-triangular parts.
2528pub trait LUDecomposable {
29+ /// The element type for the output matrices.
2630 type Elem : Scalar ;
27- type Repr : Data < Elem = Self :: Elem > ;
31+
32+ /// Type of owned data in the output matrices.
2833 type OwnedRepr : Data + RawData < Elem = Self :: Elem > ;
34+
35+ /// The output type for LU decompositions on this type.
2936 type Output : LUDecomposition < Self :: Elem , Self :: OwnedRepr > ;
37+
38+ /// The type used to represent errors in the decomposition.
3039 type Error ;
3140
41+ /// Performs an LU decomposition on the given type.
3242 fn lu ( & self ) -> Result < Self :: Output , Self :: Error > ;
3343}
3444
3848 A : Scalar ,
3949{
4050 type Elem = A ;
41- type Repr = S ;
4251 type OwnedRepr = OwnedRepr < A > ;
4352 type Error = QdkSimError ;
4453 type Output = LU < A , OwnedRepr < A > > ;
5160 QdkSimError :: CannotConvertElement ( "f64" . to_string ( ) , type_name :: < A :: Real > ( ) . to_string ( ) )
5261 } ) ?;
5362
54- let mut factors = & mut ( * self ) . to_owned ( ) ;
63+ let factors = & mut ( * self ) . to_owned ( ) ;
5564 let mut pivots: Array1 < _ > = ( 0 ..n_rows) . collect :: < Vec < _ > > ( ) . into ( ) ;
5665
5766 for j in 0 ..n_rows {
@@ -178,16 +187,84 @@ where
178187
179188#[ cfg( test) ]
180189mod tests {
181- use ndarray:: array;
190+ use approx:: assert_abs_diff_eq;
191+ use cauchy:: c64;
192+ use ndarray:: { array, Array2 , OwnedRepr } ;
182193
183- use crate :: { error:: QdkSimError , linalg:: decompositions:: LUDecomposable } ;
194+ use crate :: {
195+ error:: QdkSimError ,
196+ linalg:: decompositions:: { LUDecomposable , LU } ,
197+ } ;
184198
185199 #[ test]
186200 fn lu_decomposition_works_f64 ( ) -> Result < ( ) , QdkSimError > {
187201 let mtx = array ! [ [ 6.0 , 18.0 , 3.0 ] , [ 2.0 , 12.0 , 1.0 ] , [ 4.0 , 15.0 , 3.0 ] ] ;
188- // TODO: Actually write the test!
189- let lu = mtx. lu ( ) ?;
190- println ! ( "{:?}" , lu) ;
202+ let lu: LU < f64 , OwnedRepr < f64 > > = mtx. lu ( ) ?;
203+
204+ // NB: This tests the internal structure of the LU decomposition, and
205+ // may validly fail if the algorithm above is modified.
206+ let expected_factors = array ! [
207+ [ 6.0 , 18.0 , 3.0 ] ,
208+ [ 0.3333333333333333 , 6.0 , 0.0 ] ,
209+ [ 0.6666666666666666 , 0.5 , 1.0 ] ,
210+ ] ;
211+ for ( actual, expected) in lu. factors . iter ( ) . zip ( expected_factors. iter ( ) ) {
212+ assert_abs_diff_eq ! ( actual, expected, epsilon = 1e-6 ) ;
213+ }
214+
215+ let expected_pivots = vec ! [ 0 , 1 , 2 ] ;
216+ assert_eq ! ( lu. pivots. to_vec( ) , expected_pivots) ;
217+ Ok ( ( ) )
218+ }
219+
220+ #[ test]
221+ fn lu_decomposition_works_c64 ( ) -> Result < ( ) , QdkSimError > {
222+ // In [1]: import scipy.linalg as la
223+ // In [2]: la.lu([
224+ // ...: [-1, 1j, -2],
225+ // ...: [3, 0, -4j],
226+ // ...: [-1, 5, -1]
227+ // ...: ])
228+ // Out[2]: (array([[0., 0., 1.],
229+ // [1., 0., 0.],
230+ // [0., 1., 0.]]),
231+ // array([[ 1. +0.j , 0. +0.j , 0. +0.j ],
232+ // [-0.33333333+0.j , 1. +0.j , 0. +0.j ],
233+ // [-0.33333333+0.j , 0. +0.2j, 1. +0.j ]]),
234+ // array([[ 3. +0.j , 0. +0.j ,
235+ // -0. -4.j ],
236+ // [ 0. +0.j , 5. +0.j ,
237+ // -1. -1.33333333j],
238+ // [ 0. +0.j , 0. +0.j ,
239+ // -2.26666667-1.13333333j]]))
240+ let mtx: Array2 < c64 > = array ! [
241+ [ c64:: new( -1.0 , 0.0 ) , c64:: new( 0.0 , 1.0 ) , c64:: new( -2.0 , 0.0 ) ] ,
242+ [ c64:: new( 3.0 , 0.0 ) , c64:: new( 0.0 , 0.0 ) , c64:: new( 0.0 , -4.0 ) ] ,
243+ [ c64:: new( -1.0 , 0.0 ) , c64:: new( 5.0 , 0.0 ) , c64:: new( -1.0 , 0.0 ) ]
244+ ] ;
245+ let lu: LU < c64 , OwnedRepr < c64 > > = mtx. lu ( ) ?;
246+
247+ // NB: This tests the internal structure of the LU decomposition, and
248+ // may validly fail if the algorithm above is modified.
249+ let expected_factors = array ! [
250+ [ c64:: new( 3.0 , 0.0 ) , c64:: new( 0.0 , 0.0 ) , c64:: new( 0.0 , -4.0 ) ] ,
251+ [
252+ c64:: new( -0.3333333333333333 , 0.0 ) ,
253+ c64:: new( 5.0 , 0.0 ) ,
254+ c64:: new( -1.0 , -1.3333333333333333 )
255+ ] ,
256+ [
257+ c64:: new( -0.3333333333333333 , 0.0 ) ,
258+ c64:: new( 0.0 , 0.2 ) ,
259+ c64:: new( -2.26666667 , -1.13333333 )
260+ ] ,
261+ ] ;
262+ for ( actual, expected) in lu. factors . iter ( ) . zip ( expected_factors. iter ( ) ) {
263+ assert_abs_diff_eq ! ( actual, expected, epsilon = 1e-6 ) ;
264+ }
265+
266+ let expected_pivots = vec ! [ 1 , 2 , 2 ] ;
267+ assert_eq ! ( lu. pivots. to_vec( ) , expected_pivots) ;
191268 Ok ( ( ) )
192269 }
193270}
0 commit comments