Skip to content

Commit fe03769

Browse files
committed
WIP: batch sparse.linalg.cg
[skip ci]
1 parent 5fdcac3 commit fe03769

File tree

2 files changed

+31
-24
lines changed

2 files changed

+31
-24
lines changed

scipy/sparse/linalg/_isolve/iterative.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import warnings
2+
import functools
3+
24
import numpy as np
35
from scipy.sparse.linalg._interface import LinearOperator
46
from .utils import make_system
@@ -16,7 +18,7 @@ def _get_atol_rtol(name, b_norm, atol=0., rtol=1e-5):
1618
"if set, `atol` must be a real, non-negative number.")
1719
raise ValueError(msg)
1820

19-
atol = max(float(atol), float(rtol) * float(b_norm))
21+
atol = max(float(atol), float(rtol) * np.max(b_norm))
2022

2123
return atol, rtol
2224

@@ -377,19 +379,17 @@ def cg(A, b, x0=None, *, rtol=1e-5, atol=0., maxiter=None, M=None, callback=None
377379
True
378380
"""
379381
A, M, x, b = make_system(A, M, x0, b)
380-
bnrm2 = np.linalg.norm(b)
382+
bnrm2 = np.linalg.norm(b, axis=-1)
381383

382384
atol, _ = _get_atol_rtol('cg', bnrm2, atol, rtol)
383385

384-
if bnrm2 == 0:
386+
if not np.any(bnrm2):
385387
return b, 0
386388

387-
n = len(b)
388-
389389
if maxiter is None:
390-
maxiter = n*10
390+
maxiter = b.shape[-1] * 10
391391

392-
dotprod = np.vdot if np.iscomplexobj(x) else np.dot
392+
dotprod = np.vdot if np.iscomplexobj(x) else functools.partial(np.vecdot, axis=-1)
393393

394394
matvec = A.matvec
395395
psolve = M.matvec
@@ -399,23 +399,23 @@ def cg(A, b, x0=None, *, rtol=1e-5, atol=0., maxiter=None, M=None, callback=None
399399
rho_prev, p = None, None
400400

401401
for iteration in range(maxiter):
402-
if np.linalg.norm(r) < atol: # Are we done?
402+
if np.all(np.linalg.norm(r, axis=-1) < atol): # Are we done?
403403
return x, 0
404404

405405
z = psolve(r)
406406
rho_cur = dotprod(r, z)
407407
if iteration > 0:
408408
beta = rho_cur / rho_prev
409-
p *= beta
409+
p = (beta * p.T).T
410410
p += z
411411
else: # First spin
412412
p = np.empty_like(r)
413413
p[:] = z[:]
414414

415415
q = matvec(p)
416416
alpha = rho_cur / dotprod(p, q)
417-
x += alpha*p
418-
r -= alpha*q
417+
x += (alpha * p.T).T
418+
r -= (alpha * q.T).T
419419
rho_prev = rho_cur
420420

421421
if callback:

scipy/sparse/linalg/_isolve/utils.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
__all__ = []
44

55

6-
from numpy import asanyarray, asarray, array, zeros
6+
import numpy as np
77

88
from scipy.sparse.linalg._interface import aslinearoperator, LinearOperator, \
99
IdentityOperator
@@ -35,7 +35,7 @@ def make_system(A, M, x0, b):
3535
----------
3636
A : LinearOperator
3737
sparse or dense matrix (or any valid input to aslinearoperator)
38-
M : {LinearOperator, Nones}
38+
M : {LinearOperator, None}
3939
preconditioner
4040
sparse or dense matrix (or any valid input to aslinearoperator)
4141
x0 : {array_like, str, None}
@@ -61,14 +61,15 @@ def make_system(A, M, x0, b):
6161
A_ = A
6262
A = aslinearoperator(A)
6363

64-
if A.shape[0] != A.shape[1]:
65-
raise ValueError(f'expected square matrix, but got shape={(A.shape,)}')
64+
if (N := A.shape[-2]) != A.shape[-1]:
65+
raise ValueError(f'expected square matrix or stack of square matrices, but got shape={(A.shape,)}')
6666

67-
N = A.shape[0]
67+
b = np.asanyarray(b)
6868

69-
b = asanyarray(b)
69+
column_vector = b.ndim == 2 and b.shape[-2:] == (N, 1) # maintain column vector backwards-compatibility in 2-D case
70+
row_vector = b.shape[-1] == N # otherwise treat as a row-vector
7071

71-
if not (b.shape == (N,1) or b.shape == (N,)):
72+
if not (column_vector or row_vector):
7273
raise ValueError(f'shapes of A {A.shape} and b {b.shape} are '
7374
'incompatible')
7475

@@ -81,8 +82,9 @@ def make_system(A, M, x0, b):
8182
xtype = A.matvec(b).dtype.char
8283
xtype = coerce(xtype, b.dtype.char)
8384

84-
b = asarray(b,dtype=xtype) # make b the same type as x
85-
b = b.ravel()
85+
b = np.asarray(b, dtype=xtype) # make b the same type as x
86+
if column_vector:
87+
b = np.ravel(b)
8688

8789
# process preconditioner
8890
if M is None:
@@ -106,16 +108,21 @@ def make_system(A, M, x0, b):
106108

107109
# set initial guess
108110
if x0 is None:
109-
x = zeros(N, dtype=xtype)
111+
x = np.zeros((*M.shape[:-2], N), dtype=xtype)
110112
elif isinstance(x0, str):
111113
if x0 == 'Mb': # use nonzero initial guess ``M @ b``
112114
bCopy = b.copy()
113115
x = M.matvec(bCopy)
114116
else:
115-
x = array(x0, dtype=xtype)
116-
if not (x.shape == (N, 1) or x.shape == (N,)):
117+
x = np.array(x0, dtype=xtype)
118+
119+
column_vector = x.ndim == 2 and x.shape[-2:] == (N, 1) # maintain column vector backwards-compatibility in 2-D case
120+
row_vector = x.shape[-1] == N # otherwise treat as a row-vector
121+
122+
if not (row_vector or column_vector):
117123
raise ValueError(f'shapes of A {A.shape} and '
118124
f'x0 {x.shape} are incompatible')
119-
x = x.ravel()
125+
if column_vector:
126+
x = np.ravel(x)
120127

121128
return A, M, x, b

0 commit comments

Comments
 (0)