Rewrite inverse for triangular matrix #1612
Conversation
Hi @jessegrabowski, I haven't added a test yet, but if this approach is valid, I can add one. Please let me know. Thank you! 🙏 CC: @theorashid, @ColtAllen
Can you post some timings showing that this is advantageous? Can you also time
Thank you, @jessegrabowski. I'd be interested in seeing those numbers as well. Let me do that study and report back. Thanks again 🙏
Hi @jessegrabowski, I did the study that was suggested. I used the underlying LAPACK routine `dtrtri` directly as one of the baselines. On average, we see ~2X improvement when using `scipy.linalg.solve_triangular` over `np.linalg.inv`.

Benchmarking code:

```python
import timeit

import numpy as np
import scipy.linalg
from scipy.linalg.lapack import dtrtri

matrix_sizes = [50, 100, 250, 500, 750, 1000, 2000]
n_repeats = 100
results = {}

for size in matrix_sizes:
    print(f"Running for size {size}x{size}...")
    A_tril = np.tril(np.random.rand(size, size))
    A_tril[np.diag_indices(size)] += 1.0  # keep the matrix well-conditioned
    I = np.eye(size)

    t_inv = timeit.timeit(lambda: np.linalg.inv(A_tril), number=n_repeats)
    t_solve = timeit.timeit(lambda: np.linalg.solve(A_tril, I), number=n_repeats)
    t_solve_tri = timeit.timeit(
        lambda: scipy.linalg.solve_triangular(A_tril, I, lower=True),
        number=n_repeats,
    )

    A_fortran = np.asfortranarray(A_tril)  # LAPACK expects column-major storage
    t_dtrtri = timeit.timeit(
        lambda: dtrtri(A_fortran, lower=1),
        number=n_repeats,
    )

    results[size] = {
        "inv": t_inv,
        "solve": t_solve,
        "solve_triangular": t_solve_tri,
        "dtrtri": t_dtrtri,
        "inv_div_solve_tri": t_inv / t_solve_tri if t_solve_tri > 0 else 0,
        "solve_tri_div_dtrtri": t_solve_tri / t_dtrtri if t_dtrtri > 0 else 0,
    }
```
That's really awesome! Thanks for doing this study. Given these results, my suggestion would be to make a `TriangularInv` `Op`. We can also then add a rewrite that changes `inv` of matrices known to be triangular into that `Op`.
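A minimal sketch of what such an Op could compute numerically, using SciPy's LAPACK `dtrtri` wrapper (the function name and error handling here are illustrative assumptions, not the PR's actual implementation):

```python
import numpy as np
from scipy.linalg.lapack import dtrtri

def triangular_inv(a, lower=True):
    """Invert a triangular matrix via LAPACK *trtri (illustrative sketch)."""
    a = np.asfortranarray(a)  # LAPACK expects column-major storage
    inv_a, info = dtrtri(a, lower=int(lower))
    if info != 0:
        raise np.linalg.LinAlgError(f"trtri failed with info={info}")
    return inv_a

# Matches the generic inverse on a well-conditioned lower-triangular matrix
L = np.tril(np.random.rand(5, 5)) + np.eye(5)
assert np.allclose(triangular_inv(L, lower=True), np.linalg.inv(L))
```

The specialized path only touches the triangular half of the matrix, which is where the speedup in the benchmark above comes from.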
Codecov Report

❌ Your patch check has failed because the patch coverage (74.73%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

```
@@ Coverage Diff @@
##             main    #1612      +/-   ##
==========================================
- Coverage   81.64%   81.62%   -0.02%
==========================================
  Files         244      244
  Lines       53590    53683     +93
  Branches     9438     9464     +26
==========================================
+ Hits        43752    43821     +69
- Misses       7356     7370     +14
- Partials     2482     2492     +10
```
pytensor/tensor/rewriting/linalg.py (outdated):

```python
core_op = node.op.core_op
if not isinstance(core_op, ALL_INVERSE_OPS):
    return None
```
We've merged some changes recently, so you can basically put this in the `tracks`: #1594
Thank you, @ricardoV94. Just checking: can I eliminate this conditional in favor of this decorator: `@node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)])`?
Hi @jessegrabowski, could you please review when you have the time? I do checks similar to this PR, except, of course, we use the LAPACK solver. I'm also happy to include other operations. Also, if helpful, I can start a separate issue to track that work. Please let me know. Thanks!
Thanks for the ping! This is looking really amazing, and it's getting very close 🥳
pytensor/tensor/rewriting/linalg.py (outdated):

```python
is_upper = getattr(var.tag, "upper_triangular", False)

if is_lower or is_upper:
    return (is_lower, is_upper)
```
just returning one should be sufficient?
I thought that this gives some flexibility for returning diagonal etc. (`is_upper = True` and `is_lower = True`). No strong opinion though.
My preference would be for a separate is_diagonal helper, and to simplify this to just one return type.
Long term, I'm hoping we will have op-by-op matrix type inference, so these checks will be much easier.
pytensor/tensor/rewriting/linalg.py (outdated):

```python
is_lower, is_upper = triangular_info
if is_lower or is_upper:
    new_op = TriangularInv(lower=is_lower)
    return [new_op(A)]
```
Need to `copy_stack_trace` here.
pytensor/tensor/slinalg.py (outdated):

```python
if info > 0:
    raise np.linalg.LinAlgError("Singular matrix")
elif info < 0:
    raise ValueError(
        "illegal value in %d-th argument of internal trtri" % -info
    )
```
I would prefer if we return `np.full_like(x, np.nan)` when the algorithm fails. This is what JAX does -- it's very frustrating to have iterative algorithms (like MCMC/SGD) totally stop because of an unstable linalg operation.

Check how the Cholesky Op handles it.
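For illustration, a standalone NumPy/SciPy sketch of the NaN-on-failure convention (not the Op itself; the real Op would mirror whatever the Cholesky Op does). LAPACK's `trtri` reports a zero diagonal entry, i.e. a singular matrix, through a positive `info`:

```python
import numpy as np
from scipy.linalg.lapack import dtrtri

def triangular_inv_nan_on_error(a, lower=True):
    # info > 0 means a zero on the diagonal (singular matrix);
    # return NaNs instead of raising so iterative algorithms keep running.
    inv_a, info = dtrtri(np.asfortranarray(a), lower=int(lower))
    if info > 0:
        return np.full_like(a, np.nan)
    if info < 0:
        raise ValueError(f"illegal value in argument {-info} of trtri")
    return inv_a

singular = np.tril(np.ones((3, 3)))
singular[1, 1] = 0.0  # zero pivot makes the triangular matrix singular
out = triangular_inv_nan_on_error(singular)
assert np.isnan(out).all()
```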
```python
x_chol = cholesky(x)
y_chol = inv(x_chol)
f_chol = function([x], y_chol)
assert any(
```
Also check that there are regular Inv Ops in the graph before compiling, but not after.
Sure, feel free to open an issue. If you prefer to wait for this to be merged, that's fine too.

Thank you, @jessegrabowski! 🙏 I took a stab at the comments. Notably, I added the suggested changes. Please let me know your thoughts. Thanks again! 🙏
So you know, you can run mypy locally from inside the pytensor project folder with
For the failed float32 test, make sure you relax atol and rtol considerably when `config.floatX` is float32. Check the other tests to see what we do. We need to think of a better way to test linalg routines at reduced precision...
Thank you, @jessegrabowski. Yes, sorry, I realized belatedly about the mypy check. I ran it locally and it seems that I'm fighting this error: Should I redefine the

For the test failure, I had tried following the idiom seen elsewhere:

```python
np.testing.assert_allclose(
    f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
)
```

But perhaps I need to loosen up the requirement. I also realized some of my tests are in the wrong location.

Also, a heads up: I have to do some convoluted checks for the rewrite.
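For intuition on why the tolerance must be relaxed at float32 (a standalone sketch, not the PR's test): a single-precision triangular inverse typically agrees with the double-precision result only to a few units of float32 machine epsilon, so a float64-grade `rtol=1e-7` check would be flaky.

```python
import numpy as np
from scipy.linalg.lapack import dtrtri, strtri

rng = np.random.default_rng(0)
n = 100
# Well-conditioned lower-triangular test matrix
a64 = np.tril(rng.random((n, n))) + n * np.eye(n)

inv64, _ = dtrtri(np.asfortranarray(a64), lower=1)                    # float64
inv32, _ = strtri(np.asfortranarray(a64, dtype=np.float32), lower=1)  # float32

# Loose, float32-appropriate tolerances pass
np.testing.assert_allclose(inv32, inv64, rtol=1e-4, atol=1e-6)
```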
This PR is really amazing, we're super close. Sorry for being slow, let's try to get this merged in the next couple days!
tests/tensor/test_slinalg.py (outdated):

```python
@pytest.mark.parametrize("lower", [True, False])
def test_triangular_inv_op(lower):
```
You also need to test that the overwriting is working correctly. To do this you need to use `pytensor.In(x, mutable=overwrite_a)` in `pytensor.function`. Otherwise input variables are always treated as immutable and never overwritten. Check here for a rough example to follow.
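To see the in-place behavior at the LAPACK level (a conceptual sketch in plain NumPy/SciPy; the PR's actual test would go through `pytensor.function` with `pytensor.In` as described above): with a Fortran-ordered float64 array and `overwrite_c=1`, `dtrtri` can write the inverse into the input buffer instead of allocating a new one.

```python
import numpy as np
from scipy.linalg.lapack import dtrtri

a = np.asfortranarray(np.tril(np.random.rand(4, 4)) + np.eye(4))
expected = np.linalg.inv(a.copy())  # reference computed before overwriting

inv_a, info = dtrtri(a, lower=1, overwrite_c=1)
assert info == 0
assert np.shares_memory(inv_a, a)  # no copy was made
assert np.allclose(a, expected)    # the input buffer now holds the inverse
```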
Hi @jessegrabowski, I've tried implementing this test, but I'm not feeling very confident about it. Could you please review and let me know your thoughts? Thank you! 🙏
Commits: add other conditions to trigger rewrite; enhance TriInv Op; add tests
Double check the `is_triangular` check for the LU/QR cases, then I think this is done! Really great work!
```python
            return (True, False)
        if var.owner.outputs[2] == var:
            return (False, True)

    if isinstance(core_op, QR):
        if var.owner.outputs[1] == var:
            return (False, True)
```
These cases are still returning a tuple

Description

We add a rewrite for matrix inversion when the matrix is triangular. We check three conditions:

- `Op` is `Tri`
- `Op` is `Cholesky`

Related Issue
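The identity the rewrite relies on can be checked directly (illustrative NumPy/SciPy check, not part of the PR): for a triangular matrix, the generic inverse coincides with a triangular solve against the identity, which is the cheaper specialized path.

```python
import numpy as np
from scipy.linalg import solve_triangular

A = np.tril(np.random.rand(6, 6)) + np.eye(6)  # well-conditioned lower-triangular
inv_generic = np.linalg.inv(A)
inv_triangular = solve_triangular(A, np.eye(6), lower=True)
assert np.allclose(inv_generic, inv_triangular)
```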
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1612.org.readthedocs.build/en/1612/