2727# nice ruby interfaces for ATLAS functions.
2828#++
2929
30- require ' nmatrix/nmatrix.rb'
31- # need to have nmatrix required first or else bad things will happen
32- require_relative ' lapack_ext_common'
30+ require " nmatrix/nmatrix.rb"
31+ # need to have nmatrix required first or else bad things will happen
32+ require_relative " lapack_ext_common"
3333
3434NMatrix . register_lapack_extension ( "nmatrix-atlas" )
3535
3636require "nmatrix_atlas.so"
3737
3838class NMatrix
39-
40- #Add functions from the ATLAS C extension to the main LAPACK and BLAS modules.
41- #This will overwrite the original functions where applicable.
39+ # Add functions from the ATLAS C extension to the main LAPACK and BLAS modules.
40+ # This will overwrite the original functions where applicable.
4241 module LAPACK
4342 class << self
4443 NMatrix ::ATLAS ::LAPACK . singleton_methods . each do |m |
@@ -68,8 +67,8 @@ def posv(uplo, a, b)
6867 unless a . stype == :dense && b . stype == :dense
6968
7069 raise ( DataTypeError , "only works for non-integer, non-object dtypes" ) \
71- if a . integer_dtype? || a . object_dtype? || \
72- b . integer_dtype? || b . object_dtype?
70+ if a . integer_dtype? || a . object_dtype? || \
71+ b . integer_dtype? || b . object_dtype?
7372
7473 x = b . clone
7574 clone = a . clone
@@ -83,78 +82,81 @@ def posv(uplo, a, b)
8382 x . transpose
8483 end
8584
86- def geev ( matrix , which = :both )
85+ def geev ( matrix , which = :both )
8786 raise ( StorageTypeError , "LAPACK functions only work on dense matrices" ) \
8887 unless matrix . dense?
8988
9089 raise ( ShapeError , "eigenvalues can only be computed for square matrices" ) \
9190 unless matrix . dim == 2 && matrix . shape [ 0 ] == matrix . shape [ 1 ]
9291
93- jobvl = ( which == :both || which == :left ) ? :t : false
94- jobvr = ( which == :both || which == :right ) ? :t : false
92+ jobvl = which == :both || which == :left ? :t : false
93+ jobvr = which == :both || which == :right ? :t : false
9594
9695 n = matrix . shape [ 0 ]
9796
9897 # Outputs
9998 eigenvalues = NMatrix . new ( [ n , 1 ] , dtype : matrix . dtype )
100- # For real dtypes this holds only the real part of the eigenvalues.
99+ # For real dtypes this holds only the real part of the eigenvalues.
101100 imag_eigenvalues = matrix . complex_dtype? ? nil : NMatrix . new ( [ n , 1 ] , \
102- dtype : matrix . dtype ) # For complex dtypes, this is unused.
101+ dtype : matrix . dtype ) # For complex dtypes, this is unused.
103102 left_output = jobvl ? matrix . clone_structure : nil
104103 right_output = jobvr ? matrix . clone_structure : nil
105104
106105 # lapack_geev is a pure LAPACK routine so it expects column-major matrices,
107106 # so we need to transpose the input as well as the output.
108107 temporary_matrix = matrix . transpose
109- NMatrix ::LAPACK :: lapack_geev ( jobvl , # compute left eigenvectors of A?
110- jobvr , # compute right eigenvectors of A? (left eigenvectors of A**T)
111- n , # order of the matrix
112- temporary_matrix , # input matrix (used as work)
113- n , # leading dimension of matrix
114- eigenvalues , # real part of computed eigenvalues
115- imag_eigenvalues , # imag part of computed eigenvalues
116- left_output , # left eigenvectors, if applicable
117- n , # leading dimension of left_output
118- right_output , # right eigenvectors, if applicable
119- n , # leading dimension of right_output
120- 2 * n )
108+ NMatrix ::LAPACK . lapack_geev ( jobvl , # compute left eigenvectors of A?
109+ jobvr , # compute right eigenvectors of A? (left eigenvectors of A**T)
110+ n , # order of the matrix
111+ temporary_matrix , # input matrix (used as work)
112+ n , # leading dimension of matrix
113+ eigenvalues , # real part of computed eigenvalues
114+ imag_eigenvalues , # imag part of computed eigenvalues
115+ left_output , # left eigenvectors, if applicable
116+ n , # leading dimension of left_output
117+ right_output , # right eigenvectors, if applicable
118+ n , # leading dimension of right_output
119+ 2 * n )
121120 left_output = left_output . transpose if jobvl
122121 right_output = right_output . transpose if jobvr
123122
124-
125123 # For real dtypes, transform left_output and right_output into correct forms.
126124 # If the j'th and the (j+1)'th eigenvalues form a complex conjugate
127125 # pair, then the j'th and (j+1)'th columns of the matrix are
128126 # the real and imag parts of the eigenvector corresponding
129127 # to the j'th eigenvalue.
130- if ! matrix . complex_dtype?
128+ unless matrix . complex_dtype?
131129 complex_indices = [ ]
132130 n . times do |i |
133131 complex_indices << i if imag_eigenvalues [ i ] != 0.0
134132 end
135133
136- if ! complex_indices . empty?
134+ unless complex_indices . empty?
137135 # For real dtypes, put the real and imaginary parts together
138- eigenvalues = eigenvalues + imag_eigenvalues * \
139- Complex ( 0.0 , 1.0 )
140- left_output = left_output . cast ( dtype : \
141- NMatrix . upcast ( :complex64 , matrix . dtype ) ) if left_output
142- right_output = right_output . cast ( dtype : NMatrix . upcast ( :complex64 , \
143- matrix . dtype ) ) if right_output
136+ eigenvalues += imag_eigenvalues * \
137+ Complex ( 0.0 , 1.0 )
138+ if left_output
139+ left_output = left_output . cast ( dtype : \
140+ NMatrix . upcast ( :complex64 , matrix . dtype ) )
141+ end
142+ if right_output
143+ right_output = right_output . cast ( dtype : NMatrix . upcast ( :complex64 , \
144+ matrix . dtype ) )
145+ end
144146 end
145147
146148 complex_indices . each_slice ( 2 ) do |i , _ |
147149 if right_output
148- right_output [ 0 ...n , i ] = right_output [ 0 ...n , i ] + \
149- right_output [ 0 ...n , i + 1 ] * Complex ( 0.0 , 1.0 )
150- right_output [ 0 ...n , i + 1 ] = \
151- right_output [ 0 ...n , i ] . complex_conjugate
150+ right_output [ 0 ...n , i ] = right_output [ 0 ...n , i ] + \
151+ right_output [ 0 ...n , i + 1 ] * Complex ( 0.0 , 1.0 )
152+ right_output [ 0 ...n , i + 1 ] = \
153+ right_output [ 0 ...n , i ] . complex_conjugate
152154 end
153155
154156 if left_output
155- left_output [ 0 ...n , i ] = left_output [ 0 ...n , i ] + \
156- left_output [ 0 ...n , i + 1 ] * Complex ( 0.0 , 1.0 )
157- left_output [ 0 ...n , i + 1 ] = left_output [ 0 ...n , i ] . complex_conjugate
157+ left_output [ 0 ...n , i ] = left_output [ 0 ...n , i ] + \
158+ left_output [ 0 ...n , i + 1 ] * Complex ( 0.0 , 1.0 )
159+ left_output [ 0 ...n , i + 1 ] = left_output [ 0 ...n , i ] . complex_conjugate
158160 end
159161 end
160162 end
@@ -168,7 +170,7 @@ def geev(matrix, which=:both)
168170 end
169171 end
170172
171- def gesvd ( matrix , workspace_size = 1 )
173+ def gesvd ( matrix , workspace_size = 1 )
172174 result = alloc_svd_result ( matrix )
173175
174176 m = matrix . shape [ 0 ]
@@ -177,16 +179,16 @@ def gesvd(matrix, workspace_size=1)
177179 # This is a pure LAPACK function so it expects column-major functions.
178180 # So we need to transpose the input as well as the output.
179181 matrix = matrix . transpose
180- NMatrix ::LAPACK :: lapack_gesvd ( :a , :a , m , n , matrix , \
181- m , result [ 1 ] , result [ 0 ] , m , result [ 2 ] , n , workspace_size )
182+ NMatrix ::LAPACK . lapack_gesvd ( :a , :a , m , n , matrix , \
183+ m , result [ 1 ] , result [ 0 ] , m , result [ 2 ] , n , workspace_size )
182184 result [ 0 ] = result [ 0 ] . transpose
183185 result [ 2 ] = result [ 2 ] . transpose
184186 result
185187 end
186188
187- def gesdd ( matrix , workspace_size = nil )
189+ def gesdd ( matrix , workspace_size = nil )
188190 min_workspace_size = matrix . shape . min * \
189- ( 6 + 4 * matrix . shape . min ) + matrix . shape . max
191+ ( 6 + 4 * matrix . shape . min ) + matrix . shape . max
190192 workspace_size = min_workspace_size if \
191193 workspace_size . nil? || workspace_size < min_workspace_size
192194
@@ -198,8 +200,8 @@ def gesdd(matrix, workspace_size=nil)
198200 # This is a pure LAPACK function so it expects column-major functions.
199201 # So we need to transpose the input as well as the output.
200202 matrix = matrix . transpose
201- NMatrix ::LAPACK :: lapack_gesdd ( :a , m , n , matrix , m , result [ 1 ] , \
202- result [ 0 ] , m , result [ 2 ] , n , workspace_size )
203+ NMatrix ::LAPACK . lapack_gesdd ( :a , m , n , matrix , m , result [ 1 ] , \
204+ result [ 0 ] , m , result [ 2 ] , n , workspace_size )
203205 result [ 0 ] = result [ 0 ] . transpose
204206 result [ 2 ] = result [ 2 ] . transpose
205207 result
@@ -209,36 +211,36 @@ def gesdd(matrix, workspace_size=nil)
209211
210212 def invert!
211213 raise ( StorageTypeError , "invert only works on dense matrices currently" ) \
212- unless self . dense?
214+ unless dense?
213215
214216 raise ( ShapeError , "Cannot invert non-square matrix" ) \
215217 unless shape [ 0 ] == shape [ 1 ]
216218
217219 raise ( DataTypeError , "Cannot invert an integer matrix in-place" ) \
218- if self . integer_dtype?
220+ if integer_dtype?
219221
220222 # Even though we are using the ATLAS plugin, we still might be missing
221223 # CLAPACK (and thus clapack_getri) if we are on OS X.
222224 if NMatrix . has_clapack?
223225 # Get the pivot array; factor the matrix
224226 # We can't used getrf! here since it doesn't have the clapack behavior,
225227 # so it doesn't play nicely with clapack_getri
226- n = self . shape [ 0 ]
227- pivot = NMatrix ::LAPACK :: clapack_getrf ( :row , n , n , self , n )
228+ n = shape [ 0 ]
229+ pivot = NMatrix ::LAPACK . clapack_getrf ( :row , n , n , self , n )
228230 # Now calculate the inverse using the pivot array
229- NMatrix ::LAPACK :: clapack_getri ( :row , n , self , n , pivot )
231+ NMatrix ::LAPACK . clapack_getri ( :row , n , self , n , pivot )
230232 self
231233 else
232- __inverse__ ( self , true )
234+ __inverse__ ( self , true )
233235 end
234236 end
235237
236238 def potrf! ( which )
237239 raise ( StorageTypeError , "ATLAS functions only work on dense matrices" ) \
238- unless self . dense?
240+ unless dense?
239241 raise ( ShapeError , "Cholesky decomposition only valid for square matrices" ) \
240- unless self . dim == 2 && self . shape [ 0 ] == self . shape [ 1 ]
242+ unless dim == 2 && shape [ 0 ] == shape [ 1 ]
241243
242- NMatrix ::LAPACK :: clapack_potrf ( :row , which , self . shape [ 0 ] , self , self . shape [ 1 ] )
244+ NMatrix ::LAPACK . clapack_potrf ( :row , which , shape [ 0 ] , self , shape [ 1 ] )
243245 end
244246end
0 commit comments