From 1d4eb5ba5c708db18c14102e09f60e8fb64ea577 Mon Sep 17 00:00:00 2001
From: Asif Zubair
Date: Mon, 15 Sep 2025 21:41:11 -0500
Subject: [PATCH 1/8] add triangular rewrite

---
 pytensor/tensor/rewriting/linalg.py | 54 +++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 17a3ce9165..a0437dee5a 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -20,6 +20,7 @@
     ExtractDiag,
     Eye,
     TensorVariable,
+    Tri,
     concatenate,
     diag,
     diagonal,
@@ -1017,3 +1018,56 @@ def scalar_solve_to_division(fgraph, node):
     copy_stack_trace(old_out, new_out)
 
     return [new_out]
+
+
+def _find_triangular_op(var):
+    """
+    Inspects a variable to see if it's triangular.
+
+    Returns a tuple (is_lower, is_upper) if triangular, otherwise None.
+    """
+
+    is_lower = getattr(var.tag, "lower_triangular", False)
+    is_upper = getattr(var.tag, "upper_triangular", False)
+
+    if is_lower or is_upper:
+        return (is_lower, is_upper)
+
+    if var.owner and isinstance(var.owner.op, Tri):
+        # The 'k' parameter of Tri determines the diagonal.
+        # k=0 is the main diagonal.
+        k = var.owner.op.k
+        if k == 0:
+            is_lower = var.owner.op.lower
+            return (is_lower, not is_lower)
+
+    if var.owner and isinstance(var.owner.op, Blockwise):
+        core_op = var.owner.op.core_op
+        if isinstance(core_op, Cholesky):
+            return (core_op.lower, not core_op.lower)
+
+    return None
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([Blockwise])
+def rewrite_inv_to_triangular_solve(fgraph, node):
+    """
+    This rewrite takes advantage of the fact that the inverse of a triangular
+    matrix can be computed more efficiently than the inverse of a general
+    matrix by using a triangular solve instead of a general matrix inverse.
+    """
+    core_op = node.op.core_op
+    if not isinstance(core_op, ALL_INVERSE_OPS):
+        return None
+
+    A = node.inputs[0]
+    triangular_info = _find_triangular_op(A)
+    if triangular_info is None:
+        return None
+
+    is_lower, is_upper = triangular_info
+    if is_lower or is_upper:
+        I = pt.eye(A.shape[0], dtype=A.dtype)
+        return [solve_triangular(A, I, lower=is_lower)]

From 56c15082eafde514cbd36c7bc16d8ce213173bd0 Mon Sep 17 00:00:00 2001
From: Asif Zubair
Date: Sun, 28 Sep 2025 22:55:24 -0500
Subject: [PATCH 2/8] use new decorator pattern, lapack trtri

---
 pytensor/tensor/rewriting/linalg.py   | 19 +++---------
 pytensor/tensor/slinalg.py            | 26 +++++++++++++++-
 tests/tensor/rewriting/test_linalg.py | 43 +++++++++++++++++++++++++++
 3 files changed, 72 insertions(+), 16 deletions(-)

diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index a0437dee5a..302fb6bc26 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -20,7 +20,6 @@
     ExtractDiag,
     Eye,
     TensorVariable,
-    Tri,
     concatenate,
     diag,
     diagonal,
@@ -53,6 +52,7 @@
     Solve,
     SolveBase,
     SolveTriangular,
+    TriangularInv,
     _bilinear_solve_discrete_lyapunov,
     block_diag,
     cholesky,
@@ -1033,14 +1033,6 @@ def _find_triangular_op(var):
     if is_lower or is_upper:
         return (is_lower, is_upper)
 
-    if var.owner and isinstance(var.owner.op, Tri):
-        # The 'k' parameter of Tri determines the diagonal.
-        # k=0 is the main diagonal.
-        k = var.owner.op.k
-        if k == 0:
-            is_lower = var.owner.op.lower
-            return (is_lower, not is_lower)
-
     if var.owner and isinstance(var.owner.op, Blockwise):
         core_op = var.owner.op.core_op
         if isinstance(core_op, Cholesky):
@@ -1051,16 +1043,13 @@ def _find_triangular_op(var):
 
 @register_canonicalize
 @register_stabilize
-@node_rewriter([Blockwise])
+@node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)])
 def rewrite_inv_to_triangular_solve(fgraph, node):
     """
     This rewrite takes advantage of the fact that the inverse of a triangular
     matrix can be computed more efficiently than the inverse of a general
     matrix by using a triangular solve instead of a general matrix inverse.
     """
-    core_op = node.op.core_op
-    if not isinstance(core_op, ALL_INVERSE_OPS):
-        return None
 
     A = node.inputs[0]
     triangular_info = _find_triangular_op(A)
@@ -1069,5 +1058,5 @@ def rewrite_inv_to_triangular_solve(fgraph, node):
 
     is_lower, is_upper = triangular_info
     if is_lower or is_upper:
-        I = pt.eye(A.shape[0], dtype=A.dtype)
-        return [solve_triangular(A, I, lower=is_lower)]
+        new_op = TriangularInv(lower=is_lower)
+        return [new_op(A)]
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index cf1358813e..d83d762142 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -20,7 +20,7 @@
 from pytensor.tensor import math as ptm
 from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.nlinalg import kron, matrix_dot
+from pytensor.tensor.nlinalg import MatrixInverse, kron, matrix_dot
 from pytensor.tensor.shape import reshape
 from pytensor.tensor.type import matrix, tensor, vector
 from pytensor.tensor.variable import TensorVariable
@@ -1016,6 +1016,30 @@ def solve_triangular(
     return cast(TensorVariable, ret)
 
 
+class TriangularInv(MatrixInverse):
+    """
+    Computes the inverse of a triangular matrix.
+    """
+
+    __props__ = ("lower",)
+
+    def __init__(self, lower=True):
+        self.lower = lower
+
+    def perform(self, node, inputs, outputs):
+        (x,) = inputs
+        (z,) = outputs
+        (dtrtri,) = get_lapack_funcs(("trtri",), (x,))
+        inv, info = dtrtri(x, lower=self.lower, overwrite_c=True)
+        if info > 0:
+            raise np.linalg.LinAlgError("Singular matrix")
+        elif info < 0:
+            raise ValueError(
+                "illegal value in %d-th argument of internal trtri" % -info
+            )
+        z[0] = inv
+
+
 class Solve(SolveBase):
     """
     Solve a system of linear equations.
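Note: the new TriangularInv Op above defers to LAPACK's *trtri routine, which
inverts a triangular matrix in place and is cheaper than a general inverse
(it only ever touches the requested triangle). A standalone sketch of what
the wrapped routine computes — illustrative only, not part of the patch; it
assumes nothing beyond NumPy and SciPy being installed:

    import numpy as np
    from scipy.linalg import get_lapack_funcs

    a = np.tril(np.random.rand(5, 5) + 0.1)        # well-conditioned lower-triangular matrix
    (trtri,) = get_lapack_funcs(("trtri",), (a,))  # selects dtrtri for float64 input
    a_inv, info = trtri(a, lower=1)                # info > 0 would flag a zero diagonal entry
    assert info == 0
    # trtri writes only the requested triangle, so compare triangles
    np.testing.assert_allclose(np.tril(a_inv), np.tril(np.linalg.inv(a)))
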
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 515120e446..ab866d8b27 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -23,6 +23,7 @@
     MatrixInverse,
     MatrixPinv,
     SLogDet,
+    inv,
     matrix_inverse,
     svd,
 )
@@ -34,6 +35,7 @@
     Solve,
     SolveBase,
     SolveTriangular,
+    TriangularInv,
     cho_solve,
     cholesky,
     solve,
@@ -1060,3 +1062,44 @@ def solve_op_in_graph(graph):
     np.testing.assert_allclose(
         f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
     )
+
+
+def test_triangular_inv_op():
+    x = matrix("x")
+    f_lower = function([x], Blockwise(TriangularInv(lower=True))(x))
+    f_upper = function([x], Blockwise(TriangularInv(lower=False))(x))
+
+    # Test lower
+    a = np.tril(np.random.rand(5, 5) + 0.1)
+    a_inv = f_lower(a)
+    expected_inv = np.linalg.inv(a)
+    np.testing.assert_allclose(
+        np.tril(a_inv), np.tril(expected_inv), rtol=1e-5, atol=1e-7
+    )
+
+    # Test upper
+    a = np.triu(np.random.rand(5, 5) + 0.1)
+    a_inv = f_upper(a)
+    expected_inv = np.linalg.inv(a)
+    np.testing.assert_allclose(
+        np.triu(a_inv), np.triu(expected_inv), rtol=1e-5, atol=1e-7
+    )
+
+
+def test_inv_to_triangular_inv_rewrite():
+    x = matrix("x")
+
+    x_chol = cholesky(x)
+    y_chol = inv(x_chol)
+    f_chol = function([x], y_chol)
+    assert any(
+        isinstance(node.op, TriangularInv)
+        or (hasattr(node.op, "core_op") and isinstance(node.op.core_op, TriangularInv))
+        for node in f_chol.maker.fgraph.apply_nodes
+    )
+
+    a = np.random.rand(5, 5)
+    a = np.dot(a, a.T) + np.eye(5) * 0.1  # ensure positive definite
+    np.testing.assert_allclose(
+        f_chol(a), np.linalg.inv(np.linalg.cholesky(a)), rtol=1e-5, atol=1e-7
+    )

From 79829e25bd3db960a3be936f908367da15c0cdcc Mon Sep 17 00:00:00 2001
From: Asif Zubair
Date: Sun, 5 Oct 2025 12:16:28 -0500
Subject: [PATCH 3/8] address review comments; add other conditions to trigger
 rewrite

enhance TriInv Op
add tests

---
 pytensor/tensor/rewriting/linalg.py   |  75 ++++++++++++--
 pytensor/tensor/slinalg.py            |  38 ++++++--
 tests/tensor/rewriting/test_linalg.py | 135 +++++++++++++++++++++-----
 3 files changed, 207 insertions(+), 41 deletions(-)

diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 302fb6bc26..060fb5e8b8 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -8,6 +8,7 @@
 from pytensor import tensor as pt
 from pytensor.compile import optdb
 from pytensor.graph import Apply, FunctionGraph
+from pytensor.graph.basic import Constant
 from pytensor.graph.rewriting.basic import (
     copy_stack_trace,
     dfs_rewriter,
@@ -15,11 +16,14 @@
 )
 from pytensor.graph.rewriting.unify import OpPattern
 from pytensor.scalar.basic import Abs, Log, Mul, Sign
+from pytensor.scalar.basic import Mul as ScalarMul
+from pytensor.scalar.basic import Sub as ScalarSub
 from pytensor.tensor.basic import (
     AllocDiag,
     ExtractDiag,
     Eye,
     TensorVariable,
+    Tri,
     concatenate,
     diag,
     diagonal,
@@ -46,9 +50,12 @@
 )
 from pytensor.tensor.rewriting.blockwise import blockwise_of
 from pytensor.tensor.slinalg import (
+    LU,
+    QR,
     BlockDiagonal,
     Cholesky,
     CholeskySolve,
+    LUFactor,
     Solve,
     SolveBase,
     SolveTriangular,
@@ -1026,17 +1033,69 @@ def _find_triangular_op(var):
     """
     Inspects a variable to see if it's triangular.
 
     Returns a tuple (is_lower, is_upper) if triangular, otherwise None.
""" - + # Case 1: Check for an explicit tag is_lower = getattr(var.tag, "lower_triangular", False) is_upper = getattr(var.tag, "upper_triangular", False) - if is_lower or is_upper: return (is_lower, is_upper) - if var.owner and isinstance(var.owner.op, Blockwise): - core_op = var.owner.op.core_op - if isinstance(core_op, Cholesky): - return (core_op.lower, not core_op.lower) + if not var.owner: + return None + + op = var.owner.op + core_op = op.core_op if isinstance(op, Blockwise) else op + + # Case 2: Check for direct creator Ops + if isinstance(core_op, Cholesky): + return (core_op.lower, not core_op.lower) + + if isinstance(core_op, LU | LUFactor): + if var.owner.outputs[1] == var: + return (True, False) + if var.owner.outputs[2] == var: + return (False, True) + + if isinstance(core_op, QR): + if var.owner.outputs[1] == var: + return (False, True) + + # pt.tri will get constant folded so no point re-writing ? + # if isinstance(core_op, Tri): + # k_node = var.owner.inputs[2] + # if isinstance(k_node, Constant) and k_node.data == 0: + # print('re-writing ... ') + # return (True, False) + + # Case 3: tril/triu patterns which are implemented as Mul + if isinstance(core_op, Elemwise) and isinstance(core_op.scalar_op, ScalarMul): + other_inp = next( + (i for i in var.owner.inputs if i != var.owner.inputs[0]), None + ) + + if other_inp is not None and other_inp.owner: + # Check for tril pattern: Mul(x, Tri(...)) + if isinstance(other_inp.owner.op, Tri): + k_node = other_inp.owner.inputs[2] + if isinstance(k_node, Constant) and k_node.data == 0: + return (True, False) # It's tril + + # Check for triu pattern: Mul(x, Sub(1, Tri(k=-1))) + sub_op = other_inp.owner.op + if isinstance(sub_op, Elemwise) and isinstance(sub_op.scalar_op, ScalarSub): + sub_inputs = other_inp.owner.inputs + const_one = next( + (i for i in sub_inputs if isinstance(i, Constant) and i.data == 1), + None, + ) + tri_inp = next( + (i for i in sub_inputs if i.owner and isinstance(i.owner.op, Tri)), + None, + ) + + if const_one is not None and tri_inp is not None: + k_node = tri_inp.owner.inputs[2] + if isinstance(k_node, Constant) and k_node.data == -1: + return (False, True) # It's triu return None @@ -1059,4 +1118,6 @@ def rewrite_inv_to_triangular_solve(fgraph, node): is_lower, is_upper = triangular_info if is_lower or is_upper: new_op = TriangularInv(lower=is_lower) - return [new_op(A)] + new_inv = new_op(A) + copy_stack_trace(node.outputs[0], new_inv) + return [new_inv] diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index d83d762142..0f00b43cc1 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1021,24 +1021,46 @@ class TriangularInv(MatrixInverse): Computes the inverse of a triangular matrix. 
""" - __props__ = ("lower",) + __props__ = ("lower", "on_error", "overwrite_a") - def __init__(self, lower=True): + def __init__(self, lower=True, on_error="raise", overwrite_a=False): self.lower = lower + if on_error not in ("raise", "nan"): + raise ValueError('on_error must be one of "raise" or "nan"') + self.on_error = on_error + self.overwrite_a = overwrite_a + + if self.overwrite_a: + self.destroy_map = {0: [0]} def perform(self, node, inputs, outputs): (x,) = inputs (z,) = outputs (dtrtri,) = get_lapack_funcs(("trtri",), (x,)) inv, info = dtrtri(x, lower=self.lower, overwrite_c=True) - if info > 0: - raise np.linalg.LinAlgError("Singular matrix") - elif info < 0: - raise ValueError( - "illegal value in %d-th argument of internal trtri" % -info - ) + if info != 0: + if self.on_error == "nan": + z[0] = np.full_like(x, np.nan) + return + elif info > 0: + raise np.linalg.LinAlgError("Singular matrix") + elif info < 0: + raise ValueError( + f"illegal value in {-info}-th argument of internal trtri" + ) z[0] = inv + def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": + """ + Allows this Op to overwrite its input buffer with its output. + """ + if not allowed_inplace_inputs: + return self + + new_props = self._props_dict() + new_props["overwrite_a"] = True + return type(self)(**new_props) + class Solve(SolveBase): """ diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index ab866d8b27..f43d1e997e 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -11,8 +11,10 @@ from pytensor.compile import get_default_mode from pytensor.configdefaults import config from pytensor.graph import ancestors +from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.tensor import swapaxes +from pytensor.tensor.basic import tril, triu from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import dot, matmul @@ -38,6 +40,8 @@ TriangularInv, cho_solve, cholesky, + lu, + qr, solve, solve_triangular, ) @@ -1064,42 +1068,121 @@ def solve_op_in_graph(graph): ) -def test_triangular_inv_op(): +@pytest.mark.parametrize("lower", [True, False]) +def test_triangular_inv_op(lower): + """Tests the TriangularInv Op directly.""" x = matrix("x") - f_lower = function([x], Blockwise(TriangularInv(lower=True))(x)) - f_upper = function([x], Blockwise(TriangularInv(lower=False))(x)) + f = function([x], TriangularInv(lower=lower)(x)) - # Test lower - a = np.tril(np.random.rand(5, 5) + 0.1) - a_inv = f_lower(a) - expected_inv = np.linalg.inv(a) - np.testing.assert_allclose( - np.tril(a_inv), np.tril(expected_inv), rtol=1e-5, atol=1e-7 - ) + if lower: + a = np.tril(np.random.rand(5, 5) + 0.1) + else: + a = np.triu(np.random.rand(5, 5) + 0.1) - # Test upper - a = np.triu(np.random.rand(5, 5) + 0.1) - a_inv = f_upper(a) + a_inv = f(a) expected_inv = np.linalg.inv(a) - np.testing.assert_allclose( - np.triu(a_inv), np.triu(expected_inv), rtol=1e-5, atol=1e-7 + + # Clean the NumPy result before comparing. + if lower: + expected_inv = np.tril(expected_inv) + else: + expected_inv = np.triu(expected_inv) + + # The inverse of a triangular matrix is also triangular. + # We should check the full matrix, not just a part of it. 
+    assert_allclose(
+        a_inv, expected_inv, rtol=1e-7 if config.floatX == "float64" else 1e-5
+    )
+
+
+def test_triangular_inv_op_nan_on_error():
+    """
+    Tests the `on_error='nan'` functionality of the TriangularInv Op.
+    """
+    x = matrix("x")
+    f_nan = function([x], TriangularInv(on_error="nan")(x))
+
+    # Create a singular triangular matrix (zero on the diagonal)
+    a_singular = np.tril(np.random.rand(5, 5))
+    a_singular[2, 2] = 0
+
+    res = f_nan(a_singular)
+    assert np.all(np.isnan(res))
+
+
+def _check_op_in_graph(fgraph, op_type, present=True):
+    """Helper to check if an Op is in a graph."""
+
+    # We use type() instead of isinstance() to avoid matching subclasses
+    # (e.g., finding TriangularInv when we're looking for MatrixInverse).
+    found = any(
+        type(node.op) is op_type
+        or (hasattr(node.op, "core_op") and type(node.op.core_op) is op_type)
+        for node in fgraph.apply_nodes
+    )
+    if present:
+        assert found, f"{op_type.__name__} not found in graph"
+    else:
+        assert not found, f"{op_type.__name__} unexpectedly found in graph"
+
+
+rewrite_cases = {
+    "tril": (
+        lambda x: tril(x),
+        lambda a: np.tril(a),
+    ),
+    "triu": (
+        lambda x: triu(x),
+        lambda a: np.triu(a),
+    ),
+    "cholesky": (
+        lambda x: cholesky(x),
+        lambda a: np.linalg.cholesky(a),
+    ),
+    "lu_L": (
+        lambda x: lu(x)[1],
+        lambda a: scipy.linalg.lu(a)[1],
+    ),
+    "lu_U": (
+        lambda x: lu(x)[2],
+        lambda a: scipy.linalg.lu(a)[2],
+    ),
+    "qr_R": (
+        lambda x: qr(x)[1],
+        lambda a: np.linalg.qr(a)[1],
+    ),
+}
+
+
+@pytest.mark.parametrize("case", rewrite_cases.keys())
+def test_inv_to_triangular_inv_rewrite(case):
+    """
+    Tests the rewrite of inv(triangular) -> TriangularInv.
+ """ + x = matrix("x") + build_tri, _ = rewrite_cases[case] + x_tri = build_tri(x) + y_inv = inv(x_tri) + # Check graph BEFORE compilation + pre_compile_fgraph = FunctionGraph([x], [y_inv], clone=False) + _check_op_in_graph(pre_compile_fgraph, MatrixInverse, present=True) + _check_op_in_graph(pre_compile_fgraph, TriangularInv, present=False) + + # Trigger the rewrite + f = function([x], y_inv) + + # Check graph AFTER compilation + post_compile_fgraph = f.maker.fgraph + _check_op_in_graph(post_compile_fgraph, TriangularInv, present=True) + _check_op_in_graph(post_compile_fgraph, MatrixInverse, present=False) + + # Check numerical correctness a = np.random.rand(5, 5) - a = np.dot(a, a.T) + np.eye(5) * 0.1 # ensure positive definite - np.testing.assert_allclose( - f_chol(a), np.linalg.inv(np.linalg.cholesky(a)), rtol=1e-5, atol=1e-7 + a = np.dot(a, a.T) + np.eye(5) # Make positive definite for Cholesky + pytensor_result = f(a) + _, numpy_tri_func = rewrite_cases[case] + numpy_result = np.linalg.inv(numpy_tri_func(a)) + assert_allclose( + pytensor_result, numpy_result, rtol=1e-7 if config.floatX == "float64" else 1e-5 ) From 10dfeaa79be1f5865e32a726abbf86ac55f1d308 Mon Sep 17 00:00:00 2001 From: Asif Zubair Date: Sun, 5 Oct 2025 17:33:03 -0500 Subject: [PATCH 4/8] fix tests & mypy issues --- pytensor/tensor/slinalg.py | 2 +- tests/tensor/rewriting/test_linalg.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 0f00b43cc1..0953085b51 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1057,7 +1057,7 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": if not allowed_inplace_inputs: return self - new_props = self._props_dict() + new_props = self._props_dict() # type: ignore new_props["overwrite_a"] = True return type(self)(**new_props) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index f43d1e997e..967c87df7c 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -1071,13 +1071,13 @@ def solve_op_in_graph(graph): @pytest.mark.parametrize("lower", [True, False]) def test_triangular_inv_op(lower): """Tests the TriangularInv Op directly.""" - x = matrix("x") + x = matrix("x", dtype=config.floatX) f = function([x], TriangularInv(lower=lower)(x)) if lower: - a = np.tril(np.random.rand(5, 5) + 0.1) + a = np.tril(np.random.rand(5, 5) + 0.1).astype(config.floatX) else: - a = np.triu(np.random.rand(5, 5) + 0.1) + a = np.triu(np.random.rand(5, 5) + 0.1).astype(config.floatX) a_inv = f(a) expected_inv = np.linalg.inv(a) @@ -1099,12 +1099,13 @@ def test_triangular_inv_op_nan_on_error(): """ Tests the `on_error='nan'` functionality of the TriangularInv Op. """ - x = matrix("x") + x = matrix("x", dtype=config.floatX) f_nan = function([x], TriangularInv(on_error="nan")(x)) # Create a singular triangular matrix (zero on the diagonal) a_singular = np.tril(np.random.rand(5, 5)) a_singular[2, 2] = 0 + a_singular = a_singular.astype(config.floatX) res = f_nan(a_singular) assert np.all(np.isnan(res)) @@ -1159,7 +1160,7 @@ def test_inv_to_triangular_inv_rewrite(case): """ Tests the rewrite of inv(triangular) -> TriangularInv. 
""" - x = matrix("x") + x = matrix("x", dtype=config.floatX) build_tri, _ = rewrite_cases[case] x_tri = build_tri(x) y_inv = inv(x_tri) @@ -1179,7 +1180,9 @@ def test_inv_to_triangular_inv_rewrite(case): # Check numerical correctness a = np.random.rand(5, 5) - a = np.dot(a, a.T) + np.eye(5) # Make positive definite for Cholesky + a = (np.dot(a, a.T) + np.eye(5)).astype( + config.floatX + ) # Make positive definite for Cholesky pytensor_result = f(a) _, numpy_tri_func = rewrite_cases[case] numpy_result = np.linalg.inv(numpy_tri_func(a)) From 3cde3514074e5e09a0c87c8ac9fe7834952c5972 Mon Sep 17 00:00:00 2001 From: Asif Zubair Date: Thu, 9 Oct 2025 19:44:57 -0500 Subject: [PATCH 5/8] fix mypy error, fix tests tol, move tests --- pytensor/tensor/nlinalg.py | 2 +- tests/tensor/rewriting/test_linalg.py | 48 ++------------------------- tests/tensor/test_slinalg.py | 43 ++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 47 deletions(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index a74eff129d..3300cf1f79 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -107,7 +107,7 @@ class MatrixInverse(Op): """ - __props__ = () + __props__: tuple[str, ...] = () gufunc_signature = "(m,m)->(m,m)" gufunc_spec = ("numpy.linalg.inv", 1, 1) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 967c87df7c..186837cd3d 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -1068,49 +1068,6 @@ def solve_op_in_graph(graph): ) -@pytest.mark.parametrize("lower", [True, False]) -def test_triangular_inv_op(lower): - """Tests the TriangularInv Op directly.""" - x = matrix("x", dtype=config.floatX) - f = function([x], TriangularInv(lower=lower)(x)) - - if lower: - a = np.tril(np.random.rand(5, 5) + 0.1).astype(config.floatX) - else: - a = np.triu(np.random.rand(5, 5) + 0.1).astype(config.floatX) - - a_inv = f(a) - expected_inv = np.linalg.inv(a) - - # Clean the NumPy result before comparing. - if lower: - expected_inv = np.tril(expected_inv) - else: - expected_inv = np.triu(expected_inv) - - # The inverse of a triangular matrix is also triangular. - # We should check the full matrix, not just a part of it. - assert_allclose( - a_inv, expected_inv, rtol=1e-7 if config.floatX == "float64" else 1e-5 - ) - - -def test_triangular_inv_op_nan_on_error(): - """ - Tests the `on_error='nan'` functionality of the TriangularInv Op. 
- """ - x = matrix("x", dtype=config.floatX) - f_nan = function([x], TriangularInv(on_error="nan")(x)) - - # Create a singular triangular matrix (zero on the diagonal) - a_singular = np.tril(np.random.rand(5, 5)) - a_singular[2, 2] = 0 - a_singular = a_singular.astype(config.floatX) - - res = f_nan(a_singular) - assert np.all(np.isnan(res)) - - def _check_op_in_graph(fgraph, op_type, present=True): """Helper to check if an Op is in a graph.""" @@ -1186,6 +1143,5 @@ def test_inv_to_triangular_inv_rewrite(case): pytensor_result = f(a) _, numpy_tri_func = rewrite_cases[case] numpy_result = np.linalg.inv(numpy_tri_func(a)) - assert_allclose( - pytensor_result, numpy_result, rtol=1e-7 if config.floatX == "float64" else 1e-5 - ) + atol = rtol = 1e-8 if config.floatX.endswith("64") else 1e-4 + np.testing.assert_allclose(pytensor_result, numpy_result, rtol=rtol, atol=atol) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index 4140331036..8eb0972cad 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -19,6 +19,7 @@ Solve, SolveBase, SolveTriangular, + TriangularInv, block_diag, cho_solve, cholesky, @@ -1235,3 +1236,45 @@ def _test_fn(x, case=2, mode="reduced"): utt.verify_grad( partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random ) + + +@pytest.mark.parametrize("lower", [True, False]) +def test_triangular_inv_op(lower): + """Tests the TriangularInv Op directly.""" + x = matrix("x", dtype=config.floatX) + f = function([x], TriangularInv(lower=lower)(x)) + + if lower: + a = np.tril(np.random.rand(5, 5) + 0.1).astype(config.floatX) + else: + a = np.triu(np.random.rand(5, 5) + 0.1).astype(config.floatX) + + a_inv = f(a) + expected_inv = np.linalg.inv(a) + + # Clean the NumPy result before comparing. + if lower: + expected_inv = np.tril(expected_inv) + else: + expected_inv = np.triu(expected_inv) + + # The inverse of a triangular matrix is also triangular. + # We should check the full matrix, not just a part of it. + atol = rtol = 1e-8 if config.floatX.endswith("64") else 1e-4 + np.testing.assert_allclose(a_inv, expected_inv, rtol=rtol, atol=atol) + + +def test_triangular_inv_op_nan_on_error(): + """ + Tests the `on_error='nan'` functionality of the TriangularInv Op. + """ + x = matrix("x", dtype=config.floatX) + f_nan = function([x], TriangularInv(on_error="nan")(x)) + + # Create a singular triangular matrix (zero on the diagonal) + a_singular = np.tril(np.random.rand(5, 5)) + a_singular[2, 2] = 0 + a_singular = a_singular.astype(config.floatX) + + res = f_nan(a_singular) + assert np.all(np.isnan(res)) From 24057936ef82496f237adcd15a4f149ebfedddfb Mon Sep 17 00:00:00 2001 From: Asif Zubair Date: Thu, 9 Oct 2025 19:57:19 -0500 Subject: [PATCH 6/8] typo --- pytensor/tensor/rewriting/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 060fb5e8b8..5fce7fe8fd 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -1107,7 +1107,7 @@ def rewrite_inv_to_triangular_solve(fgraph, node): """ This rewrite takes advantage of the fact that the inverse of a triangular matrix can be computed more efficiently than the inverse of a general - matrix by using a triangular solve instead of a general matrix inverse. + matrix by using a triangular inv instead of a general matrix inverse. 
""" A = node.inputs[0] From c40159e8eea31cadca88caac765e39e70d955613 Mon Sep 17 00:00:00 2001 From: Asif Zubair Date: Sat, 18 Oct 2025 16:07:24 -0500 Subject: [PATCH 7/8] review comments: overwrite_a test + tri rewrite test --- pytensor/tensor/rewriting/linalg.py | 33 +++++----- pytensor/tensor/slinalg.py | 25 +++++++- tests/tensor/rewriting/test_linalg.py | 8 +-- tests/tensor/test_slinalg.py | 87 ++++++++++++++++----------- 4 files changed, 92 insertions(+), 61 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 5fce7fe8fd..b1ccef32e5 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -1031,13 +1031,13 @@ def _find_triangular_op(var): """ Inspects a variable to see if it's triangular. - Returns a tuple (is_lower, is_upper) if triangular, otherwise None. + Returns `True` if lower-triangular, `False` if upper-triangular, otherwise `None`. """ # Case 1: Check for an explicit tag is_lower = getattr(var.tag, "lower_triangular", False) is_upper = getattr(var.tag, "upper_triangular", False) if is_lower or is_upper: - return (is_lower, is_upper) + return is_lower if not var.owner: return None @@ -1047,7 +1047,7 @@ def _find_triangular_op(var): # Case 2: Check for direct creator Ops if isinstance(core_op, Cholesky): - return (core_op.lower, not core_op.lower) + return core_op.lower if isinstance(core_op, LU | LUFactor): if var.owner.outputs[1] == var: @@ -1060,11 +1060,10 @@ def _find_triangular_op(var): return (False, True) # pt.tri will get constant folded so no point re-writing ? - # if isinstance(core_op, Tri): - # k_node = var.owner.inputs[2] - # if isinstance(k_node, Constant) and k_node.data == 0: - # print('re-writing ... ') - # return (True, False) + if isinstance(core_op, Tri): + k_node = var.owner.inputs[2] + if isinstance(k_node, Constant) and k_node.data == 0: + return True # Case 3: tril/triu patterns which are implemented as Mul if isinstance(core_op, Elemwise) and isinstance(core_op.scalar_op, ScalarMul): @@ -1077,7 +1076,7 @@ def _find_triangular_op(var): if isinstance(other_inp.owner.op, Tri): k_node = other_inp.owner.inputs[2] if isinstance(k_node, Constant) and k_node.data == 0: - return (True, False) # It's tril + return True # It's tril # Check for triu pattern: Mul(x, Sub(1, Tri(k=-1))) sub_op = other_inp.owner.op @@ -1095,7 +1094,7 @@ def _find_triangular_op(var): if const_one is not None and tri_inp is not None: k_node = tri_inp.owner.inputs[2] if isinstance(k_node, Constant) and k_node.data == -1: - return (False, True) # It's triu + return False # It's triu return None @@ -1111,13 +1110,11 @@ def rewrite_inv_to_triangular_solve(fgraph, node): """ A = node.inputs[0] - triangular_info = _find_triangular_op(A) - if triangular_info is None: + is_lower = _find_triangular_op(A) + if is_lower is None: return None - is_lower, is_upper = triangular_info - if is_lower or is_upper: - new_op = TriangularInv(lower=is_lower) - new_inv = new_op(A) - copy_stack_trace(node.outputs[0], new_inv) - return [new_inv] + new_op = TriangularInv(lower=is_lower) + new_inv = new_op(A) + copy_stack_trace(node.outputs[0], new_inv) + return [new_inv] diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 0953085b51..1e31b2b8ef 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1036,8 +1036,27 @@ def __init__(self, lower=True, on_error="raise", overwrite_a=False): def perform(self, node, inputs, outputs): (x,) = inputs (z,) = outputs - (dtrtri,) = 
get_lapack_funcs(("trtri",), (x,)) - inv, info = dtrtri(x, lower=self.lower, overwrite_c=True) + (trtri,) = get_lapack_funcs(("trtri",), (x,)) + + # Check if we want to overwrite and if the input is C-contiguous + c_contiguous_input = self.overwrite_a and x.flags["C_CONTIGUOUS"] + if c_contiguous_input: + # Transpose C-contiguous to F-contiguous + x_in = x.T + lower_flag = not self.lower + overwrite_flag = True + else: + # Use original matrix and flags + x_in = x + lower_flag = self.lower + overwrite_flag = self.overwrite_a + + # Call trtri with the potentially transposed input and correct flags + # Use overwrite_c (LAPACK flag for trtri) based on our logic + inv_maybe_transposed, info = trtri( + x_in, lower=lower_flag, overwrite_c=overwrite_flag + ) + if info != 0: if self.on_error == "nan": z[0] = np.full_like(x, np.nan) @@ -1048,7 +1067,7 @@ def perform(self, node, inputs, outputs): raise ValueError( f"illegal value in {-info}-th argument of internal trtri" ) - z[0] = inv + z[0] = inv_maybe_transposed.T if c_contiguous_input else inv_maybe_transposed def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": """ diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 186837cd3d..cee79d158d 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -14,7 +14,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.tensor import swapaxes -from pytensor.tensor.basic import tril, triu +from pytensor.tensor.basic import triu from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import dot, matmul @@ -1085,9 +1085,9 @@ def _check_op_in_graph(fgraph, op_type, present=True): rewrite_cases = { - "tril": ( - lambda x: tril(x), - lambda a: np.tril(a), + "tri": ( + lambda x: pt.tri(x.shape[0], x.shape[1], k=0, dtype=x.dtype), + lambda a: np.tri(N=a.shape[0], M=a.shape[1], k=0, dtype=a.dtype), ), "triu": ( lambda x: triu(x), diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index 8eb0972cad..dc3828c0e4 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -8,6 +8,7 @@ import scipy from scipy import linalg as scipy_linalg +import pytensor from pytensor import function, grad from pytensor import tensor as pt from pytensor.configdefaults import config @@ -1238,43 +1239,57 @@ def _test_fn(x, case=2, mode="reduced"): ) -@pytest.mark.parametrize("lower", [True, False]) -def test_triangular_inv_op(lower): - """Tests the TriangularInv Op directly.""" - x = matrix("x", dtype=config.floatX) - f = function([x], TriangularInv(lower=lower)(x)) - - if lower: - a = np.tril(np.random.rand(5, 5) + 0.1).astype(config.floatX) - else: - a = np.triu(np.random.rand(5, 5) + 0.1).astype(config.floatX) - - a_inv = f(a) - expected_inv = np.linalg.inv(a) - - # Clean the NumPy result before comparing. - if lower: - expected_inv = np.tril(expected_inv) - else: - expected_inv = np.triu(expected_inv) - - # The inverse of a triangular matrix is also triangular. - # We should check the full matrix, not just a part of it. - atol = rtol = 1e-8 if config.floatX.endswith("64") else 1e-4 - np.testing.assert_allclose(a_inv, expected_inv, rtol=rtol, atol=atol) - - -def test_triangular_inv_op_nan_on_error(): +class TestTriangularInv: """ - Tests the `on_error='nan'` functionality of the TriangularInv Op. + Tests for the `TriangularInv` `Op`. 
""" - x = matrix("x", dtype=config.floatX) - f_nan = function([x], TriangularInv(on_error="nan")(x)) - # Create a singular triangular matrix (zero on the diagonal) - a_singular = np.tril(np.random.rand(5, 5)) - a_singular[2, 2] = 0 - a_singular = a_singular.astype(config.floatX) + @pytest.mark.parametrize("lower", [True, False]) + def test_triangular_inv_op(self, lower): + """Tests the TriangularInv Op directly.""" + x = matrix("x", dtype=config.floatX) + f = function([x], TriangularInv(lower=lower)(x)) + + if lower: + a = np.tril(np.random.rand(5, 5) + 0.1).astype(config.floatX) + else: + a = np.triu(np.random.rand(5, 5) + 0.1).astype(config.floatX) + + a_inv = f(a) + expected_inv = np.linalg.inv(a) + + # The inverse of a triangular matrix is also triangular. + # We should check the full matrix, not just a part of it. + atol = rtol = 1e-8 if config.floatX.endswith("64") else 1e-4 + np.testing.assert_allclose(a_inv, expected_inv, rtol=rtol, atol=atol) + + def test_triangular_inv_op_nan_on_error(self): + """ + Tests the `on_error='nan'` functionality of the TriangularInv Op. + """ + x = matrix("x", dtype=config.floatX) + f_nan = function([x], TriangularInv(on_error="nan")(x)) + + # Create a singular triangular matrix (zero on the diagonal) + a_singular = np.tril(np.random.rand(5, 5)) + a_singular[2, 2] = 0 + a_singular = a_singular.astype(config.floatX) + + res = f_nan(a_singular) + assert np.all(np.isnan(res)) + + @pytest.mark.parametrize("overwrite_a", [True, False]) + def test_triangular_inv_op_inplace(self, overwrite_a): + """Tests the TriangularInv Op directly.""" + x = matrix("x", dtype=config.floatX) + f = function( + [pytensor.In(x, mutable=overwrite_a)], + TriangularInv(overwrite_a=overwrite_a)(x), + accept_inplace=True, + ) + + a = np.tril(np.random.rand(5, 5) + 0.1).astype(config.floatX) + a_copy = a.copy() + f(a) - res = f_nan(a_singular) - assert np.all(np.isnan(res)) + assert overwrite_a == (not np.allclose(a, a_copy)) From 5a660c63866939f5c67ae4ef91b2939a94b883e7 Mon Sep 17 00:00:00 2001 From: Asif Zubair Date: Mon, 27 Oct 2025 21:17:52 -0500 Subject: [PATCH 8/8] improve test coverage, fix return types --- pytensor/tensor/rewriting/linalg.py | 7 +++---- tests/tensor/rewriting/test_linalg.py | 7 +++++-- tests/tensor/test_slinalg.py | 22 +++++++++++++++++++--- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index b1ccef32e5..ba545aaff2 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -1051,15 +1051,14 @@ def _find_triangular_op(var): if isinstance(core_op, LU | LUFactor): if var.owner.outputs[1] == var: - return (True, False) + return True if var.owner.outputs[2] == var: - return (False, True) + return False if isinstance(core_op, QR): if var.owner.outputs[1] == var: - return (False, True) + return False - # pt.tri will get constant folded so no point re-writing ? 
     if isinstance(core_op, Tri):
         k_node = var.owner.inputs[2]
         if isinstance(k_node, Constant) and k_node.data == 0:
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index cee79d158d..5160218ff7 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -14,7 +14,6 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.tensor import swapaxes
-from pytensor.tensor.basic import triu
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import dot, matmul
@@ -1089,8 +1088,12 @@ def _check_op_in_graph(fgraph, op_type, present=True):
         lambda x: pt.tri(x.shape[0], x.shape[1], k=0, dtype=x.dtype),
         lambda a: np.tri(N=a.shape[0], M=a.shape[1], k=0, dtype=a.dtype),
     ),
+    "tril": (
+        lambda x: pt.tril(x),
+        lambda a: np.tril(a),
+    ),
     "triu": (
-        lambda x: triu(x),
+        lambda x: pt.triu(x),
         lambda a: np.triu(a),
     ),
     "cholesky": (
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index dc3828c0e4..efd5f69b8d 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -1263,10 +1263,26 @@ def test_triangular_inv_op(self, lower):
         atol = rtol = 1e-8 if config.floatX.endswith("64") else 1e-4
         np.testing.assert_allclose(a_inv, expected_inv, rtol=rtol, atol=atol)
 
+    def test_triangular_inv_op_bad_on_error(self):
+        """Tests that a bad `on_error` value raises a ValueError."""
+        with pytest.raises(ValueError, match="on_error must be one of"):
+            TriangularInv(on_error="foo")
+
+    def test_triangular_inv_op_raise_on_error(self):
+        """Tests the default `on_error='raise'` functionality."""
+        x = matrix("x", dtype=config.floatX)
+        f_raise = function([x], TriangularInv()(x))
+
+        # Create a singular triangular matrix (zero on the diagonal)
+        a_singular = np.tril(np.random.rand(5, 5))
+        a_singular[2, 2] = 0
+        a_singular = a_singular.astype(config.floatX)
+
+        with pytest.raises(np.linalg.LinAlgError, match="Singular matrix"):
+            f_raise(a_singular)
+
     def test_triangular_inv_op_nan_on_error(self):
-        """
-        Tests the `on_error='nan'` functionality of the TriangularInv Op.
-        """
+        """Tests the `on_error='nan'` functionality of the TriangularInv Op."""
         x = matrix("x", dtype=config.floatX)
         f_nan = function([x], TriangularInv(on_error="nan")(x))
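
Note: the Case 3 detection in _find_triangular_op (PATCH 3/8) works because,
per the patch's own comment, tril/triu are implemented as Mul — elementwise
multiplication by a Tri mask, or by one minus a shifted Tri mask. The masks
it pattern-matches correspond to these NumPy identities; an illustrative
sketch, not part of the series:

    import numpy as np

    x = np.arange(9.0).reshape(3, 3)
    # tril(x) == x * tri(k=0): keep the lower triangle including the diagonal
    np.testing.assert_array_equal(np.tril(x), x * np.tri(3, 3, k=0))
    # triu(x) == x * (1 - tri(k=-1)): zero out the strictly-lower triangle
    np.testing.assert_array_equal(np.triu(x), x * (1 - np.tri(3, 3, k=-1)))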
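
Note: the overwrite path added to perform() in PATCH 7/8 leans on two facts:
LAPACK's trtri expects Fortran-ordered storage, and a C-contiguous buffer is
byte-for-byte the Fortran-ordered layout of its transpose. Inverting the
transpose with the lower/upper flag flipped and transposing the result back
is valid because inv(A.T).T == inv(A). A NumPy sketch of that identity —
illustrative only, not part of the patch:

    import numpy as np

    a = np.triu(np.random.rand(4, 4) + 0.5)  # upper-triangular, C-contiguous
    assert a.T.flags["F_CONTIGUOUS"]         # the transpose is the same buffer in Fortran order
    # transposing a triangular matrix flips lower <-> upper, hence the flag flip in the Op
    np.testing.assert_allclose(np.linalg.inv(a.T).T, np.linalg.inv(a))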
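
Note: taken together, the series makes inv() of a recognizably triangular
operand compile down to TriangularInv. A usage sketch mirroring the tests
above — it assumes this branch is installed and the default float64
precision, and adds nothing beyond what the tests already cover:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import inv
    from pytensor.tensor.slinalg import TriangularInv, cholesky

    x = pt.matrix("x")
    # inv of a Cholesky factor: _find_triangular_op sees the Cholesky creator Op
    f = pytensor.function([x], inv(cholesky(x)))
    assert any(
        isinstance(getattr(node.op, "core_op", node.op), TriangularInv)
        for node in f.maker.fgraph.apply_nodes
    )

    a = np.random.rand(5, 5)
    a = a @ a.T + np.eye(5)  # positive definite, so the Cholesky factor exists
    np.testing.assert_allclose(
        f(a), np.linalg.inv(np.linalg.cholesky(a)), rtol=1e-8, atol=1e-8
    )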