33__all__ = []
44
55
6- from numpy import asanyarray , asarray , array , zeros
6+ import numpy as np
77
88from scipy .sparse .linalg ._interface import aslinearoperator , LinearOperator , \
99 IdentityOperator
@@ -35,7 +35,7 @@ def make_system(A, M, x0, b):
3535 ----------
3636 A : LinearOperator
3737 sparse or dense matrix (or any valid input to aslinearoperator)
38- M : {LinearOperator, Nones }
38+ M : {LinearOperator, None }
3939 preconditioner
4040 sparse or dense matrix (or any valid input to aslinearoperator)
4141 x0 : {array_like, str, None}
@@ -61,14 +61,15 @@ def make_system(A, M, x0, b):
6161 A_ = A
6262 A = aslinearoperator (A )
6363
64- if A .shape [0 ] != A .shape [1 ]:
65- raise ValueError (f'expected square matrix, but got shape={ (A .shape ,)} ' )
64+ if ( N := A .shape [- 2 ]) != A .shape [- 1 ]:
65+ raise ValueError (f'expected square matrix or stack of square matrices , but got shape={ (A .shape ,)} ' )
6666
67- N = A . shape [ 0 ]
67+ b = np . asanyarray ( b )
6868
69- b = asanyarray (b )
69+ column_vector = b .ndim == 2 and b .shape [- 2 :] == (N , 1 ) # maintain column vector backwards-compatibility in 2-D case
70+ row_vector = b .shape [- 1 ] == N # otherwise treat as a row-vector
7071
71- if not (b . shape == ( N , 1 ) or b . shape == ( N ,) ):
72+ if not (column_vector or row_vector ):
7273 raise ValueError (f'shapes of A { A .shape } and b { b .shape } are '
7374 'incompatible' )
7475
@@ -81,8 +82,9 @@ def make_system(A, M, x0, b):
8182 xtype = A .matvec (b ).dtype .char
8283 xtype = coerce (xtype , b .dtype .char )
8384
84- b = asarray (b ,dtype = xtype ) # make b the same type as x
85- b = b .ravel ()
85+ b = np .asarray (b , dtype = xtype ) # make b the same type as x
86+ if column_vector :
87+ b = np .ravel (b )
8688
8789 # process preconditioner
8890 if M is None :
@@ -106,16 +108,21 @@ def make_system(A, M, x0, b):
106108
107109 # set initial guess
108110 if x0 is None :
109- x = zeros (N , dtype = xtype )
111+ x = np . zeros (( * M . shape [: - 2 ], N ) , dtype = xtype )
110112 elif isinstance (x0 , str ):
111113 if x0 == 'Mb' : # use nonzero initial guess ``M @ b``
112114 bCopy = b .copy ()
113115 x = M .matvec (bCopy )
114116 else :
115- x = array (x0 , dtype = xtype )
116- if not (x .shape == (N , 1 ) or x .shape == (N ,)):
117+ x = np .array (x0 , dtype = xtype )
118+
119+ column_vector = x .ndim == 2 and x .shape [- 2 :] == (N , 1 ) # maintain column vector backwards-compatibility in 2-D case
120+ row_vector = x .shape [- 1 ] == N # otherwise treat as a row-vector
121+
122+ if not (row_vector or column_vector ):
117123 raise ValueError (f'shapes of A { A .shape } and '
118124 f'x0 { x .shape } are incompatible' )
119- x = x .ravel ()
125+ if column_vector :
126+ x = np .ravel (x )
120127
121128 return A , M , x , b
0 commit comments