Rewrite Solve involving Kron #1557

@jessegrabowski

Description

Define $\text{vec}(X)$ as the ravel operator that works column-wise. That is the math convention, so in code it is x.reshape(-1, order="F").

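Consider the system $(A \otimes B)x = y$, with $A$ of shape $n \times n$ and $B$ of shape $m \times m$. If $y = \text{vec}(Y)$ for $Y$ of shape $m \times n$, then $x = \text{vec}(B^{-1}YA^{-T})$.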
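The identity follows from $(A \otimes B)\,\text{vec}(X) = \text{vec}(BXA^T)$: setting $BXA^T = Y$ gives $X = B^{-1}YA^{-T}$. As a sanity check outside PyTensor, here is a minimal NumPy sketch with $n \neq m$, so that a mixed-up reshape would show up as a shape or value mismatch. The helper name `kron_solve` is hypothetical and is not part of any library.

```python
import numpy as np


def kron_solve(A, B, y):
    """Solve (A kron B) x = y without forming the Kronecker product.

    Hypothetical helper for illustration; A is (n, n), B is (m, m), y is (n * m,).
    """
    n, m = A.shape[0], B.shape[0]
    # y = vec(Y) column-wise, so Y has shape (m, n)
    Y = y.reshape((m, n), order="F")
    # B^{-1} Y first, then apply A^{-T} from the right via a solve on the transpose
    X = np.linalg.solve(A, np.linalg.solve(B, Y).T).T
    # x = vec(X), again column-wise
    return X.reshape(-1, order="F")


rng = np.random.default_rng(0)
n, m = 3, 4
# Shift the diagonal so the random matrices are unlikely to be ill-conditioned
A = rng.normal(size=(n, n)) + n * np.eye(n)
B = rng.normal(size=(m, m)) + m * np.eye(m)
y = rng.normal(size=n * m)

x_dense = np.linalg.solve(np.kron(A, B), y)  # reference: dense (n*m, n*m) solve
x_fast = kron_solve(A, B, y)                 # two small solves
print(np.allclose(x_dense, x_fast))
```

The explicit `order="F"` reshapes mirror the column-wise definition of vec directly. The same result can be obtained with C-order reshapes by working with the transposes of $Y$ and $X$.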

This identity suggests a rewrite for solve Ops involving Kron:

import pytensor.tensor as pt
import pytensor
import numpy as np

A, B = pt.dmatrices('A', 'B')
y = pt.dvector('y')

n, m = A.shape[0], B.shape[0]

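# Naive version: materialize the (n*m, n*m) Kronecker product, then solve
x1 = pt.linalg.solve(pt.linalg.kron(A, B), y)
# Rewrite: two small solves. The C-order reshape yields Y.T, and the final
# C-order ravel of X.T equals the column-wise vec(X)
x2 = pt.linalg.solve(A,
                     pt.linalg.solve(B, y.reshape((n, m)).T).T).ravel()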

fn1 = pytensor.function([A, B, y], x1)
fn2 = pytensor.function([A, B, y], x2)

# Show equivalence
rng = np.random.default_rng()
n_val = 50  # concrete size; renamed so it does not shadow the symbolic n above
a_val, b_val = rng.normal(size=(2, n_val, n_val))
y_val = rng.normal(size=(n_val * n_val))

np.allclose(fn1(a_val, b_val, y_val), fn2(a_val, b_val, y_val)) # True

Timings:

%timeit fn1(a_val, b_val, y_val) # 83.9 ms ± 626 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit fn2(a_val, b_val, y_val) # 178 μs ± 8.41 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
