wip: N-D LinearOperator and cg #35

base: main
Conversation
scipy/sparse/linalg/_interface.py (outdated diff)
```diff
-        return np.hstack([self.matvec(col.reshape(-1,1)) for col in X.T])
+        # X.mT here?
+        return np.hstack([self.matvec(col.reshape(-1, 1)) for col in X.T])
```
If the new matvec can handle batched dims, do we need hstack or can it just be done all in one go?
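A minimal sketch of the "all in one go" idea, assuming the new `matvec` broadcasts over leading batch dimensions (the helper name and the batched-`matvec` contract here are assumptions, not the PR's API):

```python
import numpy as np

def matmat(matvec, X):
    # Assumed contract: matvec maps (..., n) -> (..., n) and broadcasts
    # over leading batch dims. Treat the k columns of X as a batch of k
    # vectors: (n, k) -> (k, n), apply matvec once, transpose back.
    return matvec(X.T).T

# Tiny check against the hstack-of-columns version:
A = np.arange(9.0).reshape(3, 3)
mv = lambda x: x @ A.T          # batched matvec for A
X = np.arange(6.0).reshape(3, 2)
expected = np.hstack([mv(col.reshape(1, -1)).reshape(-1, 1) for col in X.T])
assert np.allclose(matmat(mv, X), expected)
```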
Yes! This looks good in general, basically what I was thinking. Will this error in all previous cases if someone passes an N-D LinearOperator to an existing method? For instance, …
that is also TODO 👍
```diff
-    atol = max(float(atol), float(rtol) * float(b_norm))
+    atol = max(float(atol), float(rtol) * np.max(b_norm))
```
should we support different atols for different systems?
If not yet, we may want to change this to min, so the shared tolerance is at least as strict as the tightest per-system one.
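A hedged sketch of the per-system alternative (the broadcasting rule is an assumption; `b_norm` is taken to hold one norm per batched system, as in the diff above):

```python
import numpy as np

rtol, atol = 1e-5, 1e-8
b_norm = np.array([1.0, 1e3, 1e-6])     # hypothetical per-system ||b||

# One tolerance per system instead of a shared scalar: each residual
# is compared against its own atol, so small-norm systems are not
# held to the loose tolerance implied by the largest ||b||.
atol_per_system = np.maximum(atol, rtol * b_norm)    # shape (batch,)
```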
```diff
-            beta = rho_cur / rho_prev
-            p *= beta
+            beta = xpx.apply_where(
+                ~converged,
+                (rho_cur, rho_prev),
+                lambda cur, prev: cur / prev,
+                fill_value=0.0,
+                xp=np,
+            )
+            p = (beta * p.T).T
```
@izaid this is the proposed implementation of 'masking' to match what the paper is doing for alpha and beta:

> While waiting for the slowest solution to converge, all the previously converged solutions are not updated anymore; their corresponding steps α and β are set to zero. This stopped update allows to enforce that no overflow, underflow, or NaN is generated when dealing with non-zero vectors which have zero norm due to rounding errors.
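For reference, a plain-NumPy sketch of what that `apply_where` call is doing (the guard on the divisor is the point: masked-out lanes never contribute a division result, so no spurious NaN or overflow):

```python
import numpy as np

converged = np.array([False, False, True])   # hypothetical batch state
rho_cur = np.array([1.0, 2.0, 0.0])
rho_prev = np.array([2.0, 4.0, 0.0])          # last entry underflowed

# Converged systems get beta = 0, so their p stops updating; the
# guarded divisor avoids 0/0 for systems whose rho has hit zero.
safe_prev = np.where(converged, 1.0, rho_prev)
beta = np.where(converged, 0.0, rho_cur / safe_prev)
print(beta)   # [0.5 0.5 0. ]
```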
```diff
         maxiter = b.shape[-1] * 10

-    dotprod = np.vdot if np.iscomplexobj(x) else np.dot
+    dotprod = np.vdot if np.iscomplexobj(x) else functools.partial(np.vecdot, axis=-1)
```
may need axis=-1 for the complex case too
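For what it's worth, a small check suggesting `np.vecdot` (NumPy >= 2.0) could cover the complex branch as well, since it conjugates its first argument; treat this as a sketch, not the PR's final choice:

```python
import numpy as np

x = np.array([[1 + 2j, 3 - 1j],
              [0 + 1j, 2 + 0j]])

# np.vdot flattens its inputs, so it cannot reduce batch-wise.
# np.vecdot conjugates its first argument and reduces over `axis`,
# i.e. the complex inner product, one row at a time.
batched = np.vecdot(x, x, axis=-1)
rowwise = np.array([np.vdot(row, row) for row in x])
assert np.allclose(batched, rowwise)
```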
Prototype working with `array_api_strict` (and, unchanged, with `jax.numpy`):

```python
In [1]: import array_api_strict as xp; from scipy.sparse.linalg import cg, LinearOperator; import numpy as np

In [2]: def solve(N, batch, report_index=0, batched=False):
   ...:     rng = np.random.default_rng(0)
   ...:     M = rng.standard_normal((N, N))
   ...:     M = xp.asarray(M)
   ...:     reg = 1e-3
   ...:
   ...:     if batched:
   ...:         M = xp.broadcast_to(M[xp.newaxis, ...], (batch, *M.shape))
   ...:
   ...:     def matvec(x):
   ...:         return xp.squeeze(M.mT @ (M @ x[..., xp.newaxis]), axis=-1) + reg * x
   ...:
   ...:     shape = (batch, N, N) if batched else (N, N)
   ...:     A = LinearOperator(shape, matvec=matvec, dtype=xp.float64, xp=xp)
   ...:
   ...:     b = rng.standard_normal(N)
   ...:     b = xp.asarray(b)
   ...:
   ...:     if batched:
   ...:         b = xp.reshape(xp.arange(batch, dtype=xp.float64), (batch, 1)) * b
   ...:         x, info = cg(A, b, atol=1e-8, maxiter=5000)
   ...:         assert info == 0
   ...:         print(f"{x[report_index, ...]}")
   ...:     else:
   ...:         for i in xp.arange(batch, dtype=xp.float64):
   ...:             x, info = cg(A, i*b, atol=1e-8, maxiter=5000)
   ...:             assert info == 0
   ...:             if i == report_index:
   ...:                 print(x)
   ...:

In [3]: solve(5, 10, report_index=7)
Array([ 10.91985197,  -5.53737923,  -6.96397906, -35.6473016 ,
        13.48931722], dtype=array_api_strict.float64)

In [4]: solve(5, 10, report_index=7, batched=True)
Array([ 10.91985197,  -5.53737923,  -6.96397906, -35.6473016 ,
        13.48931722], dtype=array_api_strict.float64)

In [5]: import jax
   ...: jax.config.update("jax_enable_x64", True)

In [6]: import jax.numpy as xp

In [7]: solve(5, 10, report_index=7, batched=True)
[ 10.91985197  -5.53737923  -6.96397906 -35.6473016   13.48931722]
```

For JIT I think we will need to do some fancy registration of the linear operator classes as PyTrees, along the lines of https://docs.jax.dev/en/latest/_autosummary/jax.tree_util.register_pytree_node.html#jax.tree_util.register_pytree_node.
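A hedged sketch of what that registration might look like, on a toy stand-in class (the split between traced arrays and static `shape`/`matvec` fields is an assumption about the design, not the PR's implementation):

```python
import jax
import jax.numpy as jnp

class MatvecOperator:
    """Toy stand-in for LinearOperator: y = matvec(arrays, x)."""
    def __init__(self, shape, matvec, arrays):
        self.shape = shape      # static metadata
        self.matvec = matvec    # static callable
        self.arrays = arrays    # traced leaves (e.g. the matrix M)

def _flatten(op):
    # children = traced arrays; aux_data = everything static
    return (op.arrays,), (op.shape, op.matvec)

def _unflatten(aux_data, children):
    shape, matvec = aux_data
    return MatvecOperator(shape, matvec, children[0])

jax.tree_util.register_pytree_node(MatvecOperator, _flatten, _unflatten)

M = jnp.eye(3)
op = MatvecOperator((3, 3), lambda a, x: a @ x, M)

@jax.jit
def apply(op, x):
    # op crosses the jit boundary as a PyTree: M is traced, the rest static.
    return op.matvec(op.arrays, x)

print(apply(op, jnp.ones(3)))   # [1. 1. 1.]
```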
Oh that's great! Yes, might need to register LinearOperator as a PyTree, but not such a big deal I think?
I think the JIT will probably require a separate backend that uses JAX-specific things. In particular, for:

```python
converged = xp_vector_norm(r, axis=-1) < atol
if xp.all(converged):
    return x, 0
```

I think we will need to use something from https://docs.jax.dev/en/latest/control-flow.html#structured-control-flow-primitives.
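As an illustration of the kind of rewrite that implies, here is a hedged sketch of the early exit expressed with `jax.lax.while_loop` (the loop-state layout, `step` signature, and `info` convention are assumptions, not the PR's backend):

```python
import jax
import jax.numpy as jnp

def cg_like_loop(step, x0, r0, atol, maxiter):
    # The Python early exit `if xp.all(converged): return x, 0` becomes
    # the while_loop condition, so the loop stays traceable under jit.
    def cond(state):
        _, r, it = state
        converged = jnp.linalg.norm(r, axis=-1) < atol
        return (it < maxiter) & ~jnp.all(converged)

    def body(state):
        x, r, it = state
        x, r = step(x, r)   # one solver iteration
        return x, r, it + 1

    x, r, it = jax.lax.while_loop(cond, body, (x0, r0, 0))
    all_done = jnp.all(jnp.linalg.norm(r, axis=-1) < atol)
    return x, jnp.where(all_done, 0, it)

# Tiny usage with a toy contraction standing in for a CG step:
step = lambda x, r: (x + 0.5 * r, 0.5 * r)
x, info = jax.jit(cg_like_loop, static_argnums=0)(
    step, jnp.zeros(4), jnp.ones(4), 1e-6, 100)
print(x, info)   # converges well before maxiter, so info == 0
```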
TODO: raise an error in all algorithms that still require 2-D input
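One possible shape of that guard, as a hypothetical helper (name, message, and call sites are placeholders):

```python
def _require_2d(A, name):
    # Hypothetical validation: solvers that have not been batched yet
    # would call this up front and reject N-D operators explicitly
    # instead of failing somewhere deep inside the iteration.
    if len(A.shape) != 2:
        raise ValueError(
            f"{name} does not support batched (N-D) operators; "
            f"got shape {A.shape}"
        )
```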
https://github.com/kokkos/kokkos-kernels/blob/develop/batched/sparse/impl/KokkosBatched_CG_Team_Impl.hpp
https://ieeexplore.ieee.org/document/10054414 section VI