Conversation

@lucascolley (Owner) commented Oct 16, 2025


return np.hstack([self.matvec(col.reshape(-1,1)) for col in X.T])
# X.mT here?
return np.hstack([self.matvec(col.reshape(-1, 1)) for col in X.T])
Collaborator commented:

If the new matvec can handle batched dims, do we need hstack or can it just be done all in one go?
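
If it can, a rough "all in one go" version might look something like this sketch (the name matmat_batched and the moveaxis trick are illustrative assumptions, not the PR's implementation):

import numpy as np

def matmat_batched(matvec, X):
    # X has shape (n, k): treat its k columns as a leading batch axis,
    # call the batched matvec once, then move the axis back
    cols = np.moveaxis(X, -1, 0)       # (k, n)
    out = matvec(cols)                 # (k, m), assuming matvec broadcasts over the batch
    return np.moveaxis(out, 0, -1)     # (m, k)

# toy check against the column-by-column hstack approach
A = np.arange(12.0).reshape(3, 4)
matvec = lambda x: np.squeeze(A @ x[..., None], axis=-1)
X = np.ones((4, 2))
assert np.allclose(matmat_batched(matvec, X), A @ X)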

@izaid (Collaborator) commented Oct 16, 2025

Yes! This looks good in general, basically what I was thinking. Will this error in all previous cases if someone passes an N-D LinearOperator to an existing method? For instance, gmres should throw an error if the LinearOperator it receives has ndim != 2.

@lucascolley (Owner Author) commented:

> Will this error in all previous cases if someone passes an N-D LinearOperator to an existing method?

that is also TODO 👍
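
A minimal sketch of the kind of guard the existing solvers could gain (the helper name _check_2d_operator is hypothetical, not scipy API):

def _check_2d_operator(A, solver_name):
    # e.g. called near the top of gmres(): _check_2d_operator(A, "gmres")
    if getattr(A, "ndim", len(A.shape)) != 2:
        raise ValueError(
            f"{solver_name} does not support N-D LinearOperators; "
            f"got an operator with shape {A.shape}"
        )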

@lucascolley force-pushed the linearoperator-nd branch 4 times, most recently from 39c5712 to fe03769 on October 21, 2025 at 22:39
@lucascolley changed the title from "wip: N-D LinearOperator" to "wip: N-D LinearOperator and cg" on Oct 24, 2025
Comment on lines 19 to 23
atol = max(float(atol), float(rtol) * float(b_norm))
atol = max(float(atol), float(rtol) * np.max(b_norm))
Owner Author commented:

should we support different atols for different systems?

Owner Author commented:

If not yet, we may want to change this to min.
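
For reference, a sketch of both scalar options next to what a per-system tolerance could look like if we did support it (illustrative only, not the PR's code):

import numpy as np

atol, rtol = 1e-8, 1e-5
b_norm = np.array([1.0, 1e3, 1e-2])   # one norm per batched system

# scalar variants: np.max is the most lenient, np.min the most conservative
atol_loose  = max(float(atol), float(rtol) * np.max(b_norm))
atol_strict = max(float(atol), float(rtol) * np.min(b_norm))

# per-system variant: each system gets its own effective tolerance
atol_per_system = np.maximum(atol, rtol * b_norm)   # [1e-05, 1e-02, 1e-07]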

Comment on lines 408 to 419
beta = rho_cur / rho_prev
p *= beta
beta = xpx.apply_where(
    ~converged,
    (rho_cur, rho_prev),
    lambda cur, prev: cur / prev,
    fill_value=0.0,
    xp=np
)
p = (beta * p.T).T
Owner Author commented:

@izaid this is the proposed implementation of 'masking' to match what the paper is doing for alpha and beta:

> While waiting for the slowest solution to converge, all the previously converged solutions are not updated anymore; their corresponding steps α and β are set to zero. This stopped update allows to enforce that no overflow, underflow, or NaN is generated when dealing with non-zero vectors which have zero norm due to rounding errors.

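For completeness, the same masking applied to alpha might look roughly like this sketch (the names rho_cur, pAp, and converged are assumptions, not necessarily the PR's variables):

import numpy as np
import array_api_extra as xpx

converged = np.array([True, False, False])   # one flag per batched system
rho_cur = np.array([1e-30, 2.0, 3.0])        # r . r for each system
pAp = np.array([0.0, 4.0, 5.0])              # p . (A p); can be ~0 once converged

alpha = xpx.apply_where(
    ~converged,                  # only update systems that have not yet converged
    (rho_cur, pAp),
    lambda num, den: num / den,  # the usual CG step size
    fill_value=0.0,              # alpha = 0 freezes the converged solutions
    xp=np
)
# alpha == [0. , 0.5, 0.6]; no division by the near-zero pAp of the converged system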

maxiter = b.shape[-1] * 10

dotprod = np.vdot if np.iscomplexobj(x) else np.dot
dotprod = np.vdot if np.iscomplexobj(x) else functools.partial(np.vecdot, axis=-1)
Owner Author commented:

We may need axis=-1 for the complex case too.
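
A quick sketch of why that might just work: np.vecdot already conjugates its first argument, so a single vecdot-based dotprod with axis=-1 could plausibly cover both the real and complex batched cases (untested against this PR):

import functools
import numpy as np

dotprod = functools.partial(np.vecdot, axis=-1)   # np.vecdot requires NumPy >= 2.0

r_real = np.ones((4, 3))                # batch of 4 real residuals
r_cplx = np.full((4, 3), 1.0 + 1.0j)    # batch of 4 complex residuals

print(dotprod(r_real, r_real))          # [3. 3. 3. 3.]
print(dotprod(r_cplx, r_cplx))          # [6.+0.j 6.+0.j 6.+0.j 6.+0.j], i.e. sum(conj(r) * r) like np.vdot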

@lucascolley force-pushed the linearoperator-nd branch 2 times, most recently from 7ceafca to b6739f0 on November 6, 2025 at 23:24
@lucascolley force-pushed the linearoperator-nd branch 2 times, most recently from 7f4e133 to 0569d5b on November 14, 2025 at 17:41
@lucascolley (Owner Author) commented:

Prototype working with array_api_strict and jax.numpy! (no JIT yet)

In [1]: import array_api_strict as xp; from scipy.sparse.linalg import cg, LinearOperator; import numpy as np

In [2]: def solve(N, batch, report_index=0, batched=False):
   ...:     rng = np.random.default_rng(0)
   ...:     M = rng.standard_normal((N, N))
   ...:     M = xp.asarray(M)
   ...:     reg = 1e-3
   ...:
   ...:     if batched:
   ...:         M = xp.broadcast_to(M[xp.newaxis, ...], (batch, *M.shape))
   ...:
   ...:     def matvec(x):
   ...:         return xp.squeeze(M.mT @ (M @ x[..., xp.newaxis]), axis=-1) + reg * x
   ...:
   ...:     shape = (batch, N, N) if batched else (N, N)
   ...:     A = LinearOperator(shape, matvec=matvec, dtype=xp.float64, xp=xp)
   ...:
   ...:     b = rng.standard_normal(N)
   ...:     b = xp.asarray(b)
   ...:
   ...:     if batched:
   ...:         b = xp.reshape(xp.arange(batch, dtype=xp.float64), (batch, 1)) * b
   ...:         x, info = cg(A, b, atol=1e-8, maxiter=5000)
   ...:         assert info == 0
   ...:         print(f"{x[report_index, ...]}")
   ...:     else:
   ...:         for i in xp.arange(batch, dtype=xp.float64):
   ...:             x, info = cg(A, i*b, atol=1e-8, maxiter=5000)
   ...:             assert info == 0
   ...:             if i == report_index:
   ...:                 print(x)
   ...:

In [3]: solve(5, 10, report_index=7)
Array([ 10.91985197,  -5.53737923,  -6.96397906, -35.6473016 ,  13.48931722], dtype=array_api_strict.float64)

In [4]: solve(5, 10, report_index=7, batched=True)
Array([ 10.91985197,  -5.53737923,  -6.96397906, -35.6473016 ,  13.48931722], dtype=array_api_strict.float64)

In [5]: import jax
   ...: jax.config.update("jax_enable_x64", True)

In [6]: import jax.numpy as xp

In [7]: solve(5, 10, report_index=7, batched=True)
[ 10.91985197  -5.53737923  -6.96397906 -35.6473016   13.48931722]

For JIT I think we will need to do some fancy registration of the linear operator classes as PyTrees, along the lines of https://docs.jax.dev/en/latest/_autosummary/jax.tree_util.register_pytree_node.html#jax.tree_util.register_pytree_node.
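
A rough sketch of what that registration could look like for a toy matrix-backed operator (the class name MatrixLinearOperator and its attributes are made up for illustration, not the PR's classes):

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

class MatrixLinearOperator:
    def __init__(self, A):
        self.A = A            # dynamic leaf: the underlying array
        self.shape = A.shape  # static metadata, recomputed on unflatten
        self.dtype = A.dtype

    def matvec(self, x):
        return jnp.squeeze(self.A @ x[..., None], axis=-1)

def _flatten(op):
    return (op.A,), None      # (children, static aux_data)

def _unflatten(aux_data, children):
    return MatrixLinearOperator(children[0])

register_pytree_node(MatrixLinearOperator, _flatten, _unflatten)

@jax.jit
def apply(op, x):
    # op can now be passed straight through jit because it is a registered PyTree
    return op.matvec(x)

print(apply(MatrixLinearOperator(2.0 * jnp.eye(3)), jnp.ones(3)))   # [2. 2. 2.]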

@izaid (Collaborator) commented Nov 15, 2025

Oh that's great! Yes, we might need to register the LinearOperator classes as PyTrees, but that's not such a big deal, I think?

@lucascolley (Owner Author) commented:

I think the JIT will probably require a separate backend that uses JAX-specific things. In particular, for:

converged = xp_vector_norm(r, axis=-1) < atol
if xp.all(converged):
    return x, 0

I think we will need to use something from https://docs.jax.dev/en/latest/control-flow.html#structured-control-flow-primitives.
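
For example, a toy sketch of the kind of restructuring that implies, where the early return becomes part of a jax.lax.while_loop condition (illustrative only, not the PR's cg):

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def cg_jax(matvec, b, atol=1e-8, maxiter=1000):
    x0 = jnp.zeros_like(b)
    r0 = b - matvec(x0)

    def cond(state):
        x, r, p, k = state
        # keep iterating while the residual is too large and the budget remains
        return jnp.logical_and(jnp.linalg.norm(r) >= atol, k < maxiter)

    def body(state):
        x, r, p, k = state
        Ap = matvec(p)
        alpha = (r @ r) / (p @ Ap)
        x_new = x + alpha * p
        r_new = r - alpha * Ap
        beta = (r_new @ r_new) / (r @ r)
        return x_new, r_new, r_new + beta * p, k + 1

    x, r, p, k = jax.lax.while_loop(cond, body, (x0, r0, r0, 0))
    return x

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
print(cg_jax(lambda v: A @ v, jnp.array([1.0, 2.0])))   # ~[0.0909 0.6364]

Because the convergence check lives inside while_loop rather than a Python if, this version can be wrapped in jax.jit (with matvec marked static).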
