-
Couldn't load subscription status.
- Fork 146
Rewrite inverse for triangular matrix #1612
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
asifzubair
wants to merge
7
commits into
pymc-devs:main
Choose a base branch
from
asifzubair:azubair/enh-573-rewrite-inv-triangular
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+311
−2
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
7c11584
add triangular rewrite
asifzubair a1e19e7
use new decorator pattern, lapack trtri
asifzubair 8d54703
address review comments;
asifzubair 4c5e21d
fix tests & mypy issues
asifzubair c7af980
fix mypy error, fix tests tol, move tests
asifzubair 208316d
typo
asifzubair 14fec15
review comments: overwrite_a test + tri rewrite test
asifzubair File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,18 +8,22 @@ | |
| from pytensor import tensor as pt | ||
| from pytensor.compile import optdb | ||
| from pytensor.graph import Apply, FunctionGraph | ||
| from pytensor.graph.basic import Constant | ||
| from pytensor.graph.rewriting.basic import ( | ||
| copy_stack_trace, | ||
| dfs_rewriter, | ||
| node_rewriter, | ||
| ) | ||
| from pytensor.graph.rewriting.unify import OpPattern | ||
| from pytensor.scalar.basic import Abs, Log, Mul, Sign | ||
| from pytensor.scalar.basic import Mul as ScalarMul | ||
| from pytensor.scalar.basic import Sub as ScalarSub | ||
| from pytensor.tensor.basic import ( | ||
| AllocDiag, | ||
| ExtractDiag, | ||
| Eye, | ||
| TensorVariable, | ||
| Tri, | ||
| concatenate, | ||
| diag, | ||
| diagonal, | ||
|
|
@@ -46,12 +50,16 @@ | |
| ) | ||
| from pytensor.tensor.rewriting.blockwise import blockwise_of | ||
| from pytensor.tensor.slinalg import ( | ||
| LU, | ||
| QR, | ||
| BlockDiagonal, | ||
| Cholesky, | ||
| CholeskySolve, | ||
| LUFactor, | ||
| Solve, | ||
| SolveBase, | ||
| SolveTriangular, | ||
| TriangularInv, | ||
| _bilinear_solve_discrete_lyapunov, | ||
| block_diag, | ||
| cholesky, | ||
|
|
@@ -1017,3 +1025,96 @@ def scalar_solve_to_division(fgraph, node): | |
| copy_stack_trace(old_out, new_out) | ||
|
|
||
| return [new_out] | ||
|
|
||
|
|
||
| def _find_triangular_op(var): | ||
| """ | ||
| Inspects a variable to see if it's triangular. | ||
|
|
||
| Returns `True` if lower-triangular, `False` if upper-triangular, otherwise `None`. | ||
| """ | ||
| # Case 1: Check for an explicit tag | ||
| is_lower = getattr(var.tag, "lower_triangular", False) | ||
| is_upper = getattr(var.tag, "upper_triangular", False) | ||
| if is_lower or is_upper: | ||
| return is_lower | ||
|
|
||
| if not var.owner: | ||
| return None | ||
|
|
||
| op = var.owner.op | ||
| core_op = op.core_op if isinstance(op, Blockwise) else op | ||
|
|
||
| # Case 2: Check for direct creator Ops | ||
| if isinstance(core_op, Cholesky): | ||
| return core_op.lower | ||
|
|
||
| if isinstance(core_op, LU | LUFactor): | ||
| if var.owner.outputs[1] == var: | ||
| return (True, False) | ||
| if var.owner.outputs[2] == var: | ||
| return (False, True) | ||
|
|
||
| if isinstance(core_op, QR): | ||
| if var.owner.outputs[1] == var: | ||
| return (False, True) | ||
|
Comment on lines
+1054
to
+1060
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These cases are still returning a tuple |
||
|
|
||
| # pt.tri will get constant folded so no point re-writing ? | ||
asifzubair marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if isinstance(core_op, Tri): | ||
| k_node = var.owner.inputs[2] | ||
| if isinstance(k_node, Constant) and k_node.data == 0: | ||
| return True | ||
|
|
||
| # Case 3: tril/triu patterns which are implemented as Mul | ||
| if isinstance(core_op, Elemwise) and isinstance(core_op.scalar_op, ScalarMul): | ||
| other_inp = next( | ||
| (i for i in var.owner.inputs if i != var.owner.inputs[0]), None | ||
| ) | ||
|
|
||
| if other_inp is not None and other_inp.owner: | ||
| # Check for tril pattern: Mul(x, Tri(...)) | ||
| if isinstance(other_inp.owner.op, Tri): | ||
| k_node = other_inp.owner.inputs[2] | ||
| if isinstance(k_node, Constant) and k_node.data == 0: | ||
| return True # It's tril | ||
|
|
||
| # Check for triu pattern: Mul(x, Sub(1, Tri(k=-1))) | ||
jessegrabowski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| sub_op = other_inp.owner.op | ||
| if isinstance(sub_op, Elemwise) and isinstance(sub_op.scalar_op, ScalarSub): | ||
| sub_inputs = other_inp.owner.inputs | ||
| const_one = next( | ||
| (i for i in sub_inputs if isinstance(i, Constant) and i.data == 1), | ||
| None, | ||
| ) | ||
| tri_inp = next( | ||
| (i for i in sub_inputs if i.owner and isinstance(i.owner.op, Tri)), | ||
| None, | ||
| ) | ||
|
|
||
| if const_one is not None and tri_inp is not None: | ||
| k_node = tri_inp.owner.inputs[2] | ||
| if isinstance(k_node, Constant) and k_node.data == -1: | ||
| return False # It's triu | ||
|
|
||
| return None | ||
|
|
||
|
|
||
| @register_canonicalize | ||
| @register_stabilize | ||
| @node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)]) | ||
| def rewrite_inv_to_triangular_solve(fgraph, node): | ||
| """ | ||
| This rewrite takes advantage of the fact that the inverse of a triangular | ||
| matrix can be computed more efficiently than the inverse of a general | ||
| matrix by using a triangular inv instead of a general matrix inverse. | ||
| """ | ||
|
|
||
| A = node.inputs[0] | ||
| is_lower = _find_triangular_op(A) | ||
| if is_lower is None: | ||
| return None | ||
|
|
||
| new_op = TriangularInv(lower=is_lower) | ||
| new_inv = new_op(A) | ||
| copy_stack_trace(node.outputs[0], new_inv) | ||
| return [new_inv] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.