Skip to content

Commit 7ceafca

Browse files
committed
WIP: batch sparse.linalg.cg
[skip ci]
1 parent fe03769 commit 7ceafca

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

scipy/sparse/linalg/_isolve/iterative.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from .utils import make_system
77
from scipy.linalg import get_lapack_funcs
88

9+
from scipy._lib import array_api_extra as xpx
10+
911
__all__ = ['bicg', 'bicgstab', 'cg', 'cgs', 'gmres', 'qmr']
1012

1113

@@ -399,23 +401,40 @@ def cg(A, b, x0=None, *, rtol=1e-5, atol=0., maxiter=None, M=None, callback=None
399401
rho_prev, p = None, None
400402

401403
for iteration in range(maxiter):
402-
if np.all(np.linalg.norm(r, axis=-1) < atol): # Are we done?
404+
converged = np.linalg.norm(r, axis=-1) < atol
405+
if np.all(converged):
403406
return x, 0
404407

405408
z = psolve(r)
406409
rho_cur = dotprod(r, z)
410+
407411
if iteration > 0:
408-
beta = rho_cur / rho_prev
409-
p = (beta * p.T).T
412+
beta = xpx.apply_where(
413+
~converged,
414+
(rho_cur, rho_prev),
415+
lambda cur, prev: cur / prev,
416+
fill_value=0.0,
417+
xp=np
418+
)
419+
p *= beta
410420
p += z
411421
else: # First spin
412422
p = np.empty_like(r)
413423
p[:] = z[:]
414424

415425
q = matvec(p)
416-
alpha = rho_cur / dotprod(p, q)
417-
x += (alpha * p.T).T
418-
r -= (alpha * q.T).T
426+
c = dotprod(p, q)
427+
428+
alpha = xpx.apply_where(
429+
~converged,
430+
(rho_cur, c),
431+
lambda rc, c: rc / c,
432+
fill_value=0.0,
433+
xp=np
434+
)
435+
436+
x += alpha*p
437+
r -= alpha*q
419438
rho_prev = rho_cur
420439

421440
if callback:

0 commit comments

Comments
 (0)